From 2c3b9cbce80b809d57f4ad3826c3f65dd5b82c91 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 4 Jun 2025 17:03:56 +0100 Subject: [PATCH 01/27] add hf inference providers support --- .../pydantic_ai/models/huggingface.py | 432 ++++++++++++++++++ .../pydantic_ai/providers/__init__.py | 4 + .../pydantic_ai/providers/huggingface.py | 72 +++ pydantic_ai_slim/pyproject.toml | 2 + pyproject.toml | 2 +- uv.lock | 34 +- 6 files changed, 539 insertions(+), 7 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/models/huggingface.py create mode 100644 pydantic_ai_slim/pydantic_ai/providers/huggingface.py diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py new file mode 100644 index 0000000000..c34d741a37 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -0,0 +1,432 @@ +from __future__ import annotations as _annotations + +import base64 +from collections.abc import AsyncIterable, AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Literal, cast, overload + +from typing_extensions import assert_never + +from pydantic_ai.providers import Provider, infer_provider + +from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc +from ..messages import ( + AudioUrl, + BinaryContent, + DocumentUrl, + ImageUrl, + ModelMessage, + ModelRequest, + ModelResponse, + ModelResponsePart, + ModelResponseStreamEvent, + RetryPromptPart, + SystemPromptPart, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, + VideoUrl, +) +from ..settings import ModelSettings +from ..tools import ToolDefinition +from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests + +try: + import aiohttp + from huggingface_hub import ( + AsyncInferenceClient, + ChatCompletionInputMessage, + ChatCompletionInputMessageChunk, + ChatCompletionInputTool, + ChatCompletionInputToolCall, + ChatCompletionInputURL, + ChatCompletionOutput, + ChatCompletionOutputMessage, + ChatCompletionStreamOutput, + InferenceTimeoutError, + ) + +except ImportError as _import_error: + raise ImportError( + 'Please install `huggingface_hub` to use Hugging Face Inference Providers, ' + 'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`' + ) from _import_error + +__all__ = ( + 'HuggingFaceModel', + 'HuggingFaceModelSettings', +) + + +HFSystemPromptRole = Literal['system', 'user'] + + +class HuggingFaceModelSettings(ModelSettings, total=False): + """Settings used for a Hugging Face model request. + + ALL FIELDS MUST BE `huggingface_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. + """ + + # This class is a placeholder for any future huggingface-specific settings + + +@dataclass(init=False) +class HuggingFaceModel(Model): + """A model that uses Hugging Face Inference Providers. + + Internally, this uses the [HF Python client](https://github.com/huggingface/huggingface_hub) to interact with the API. + + Apart from `__init__`, all methods are private or match those of the base class. 
+ """ + + client: AsyncInferenceClient = field(repr=False) + + _model_name: str = field(repr=False) + _system: str = field(default='huggingface', repr=False) + + def __init__( + self, + model_name: str, + *, + provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface', + ): + """Initialize a Hugging Face model. + + Args: + model_name: The name of the Model to use. You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending). + provider: The provider to use for Hugging Face Inference Providers. Can be either the string 'huggingface' or an + instance of `Provider[AsyncInferenceClient]`. If not provided, the other parameters will be used. + """ + self._model_name = model_name + self._provider = provider + if isinstance(provider, str): + provider = infer_provider(provider) + self.client = provider.client + + async def request( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> ModelResponse: + check_allow_model_requests() + response = await self._completions_create( + messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters + ) + model_response = self._process_response(response) + model_response.usage.requests = 1 + return model_response + + @asynccontextmanager + async def request_stream( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> AsyncIterator[StreamedResponse]: + check_allow_model_requests() + response = await self._completions_create( + messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters + ) + yield await self._process_streamed_response(response) + + @property + def model_name(self) -> str: + """The model name.""" + return self._model_name + + @property + def system(self) -> str: + """The system / model provider.""" + return self._system + + @overload + async def _completions_create( + self, + messages: list[ModelMessage], + stream: Literal[True], + model_settings: HuggingFaceModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> AsyncIterable[ChatCompletionStreamOutput]: ... + + @overload + async def _completions_create( + self, + messages: list[ModelMessage], + stream: Literal[False], + model_settings: HuggingFaceModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> ChatCompletionOutput: ... 
+ + async def _completions_create( + self, + messages: list[ModelMessage], + stream: bool, + model_settings: HuggingFaceModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> ChatCompletionOutput | AsyncIterable[ChatCompletionStreamOutput]: + tools = self._get_tools(model_request_parameters) + + if not tools: + tool_choice: Literal['none', 'required', 'auto'] | None = None + elif not model_request_parameters.allow_text_output: + tool_choice = 'required' + else: + tool_choice = 'auto' + + hf_messages = await self._map_messages(messages) + + try: + return await self.client.chat.completions.create( # type: ignore + model=self._model_name, + messages=hf_messages, # type: ignore + tools=tools, + tool_choice=tool_choice or None, + stream=stream, + stop=model_settings.get('stop_sequences', None), + temperature=model_settings.get('temperature', None), + top_p=model_settings.get('top_p', None), + seed=model_settings.get('seed', None), + presence_penalty=model_settings.get('presence_penalty', None), + frequency_penalty=model_settings.get('frequency_penalty', None), + logit_bias=model_settings.get('logit_bias', None), # type: ignore + logprobs=model_settings.get('logprobs', None), + top_logprobs=model_settings.get('top_logprobs', None), + extra_body=model_settings.get('extra_body'), # type: ignore + ) + except (InferenceTimeoutError, aiohttp.ClientResponseError) as e: + if isinstance(e, aiohttp.ClientResponseError): + raise ModelHTTPError( + status_code=e.status, + model_name=self.model_name, + body=e.response_error_payload, # type: ignore + ) from e + raise # pragma: lax no cover + + def _process_response(self, response: ChatCompletionOutput) -> ModelResponse: + """Process a non-streamed response, and prepare a message to return.""" + if response.created: + timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc) + else: + timestamp = _now_utc() + + choice = response.choices[0] + items: list[ModelResponsePart] = [] + + if choice.message.content is not None: + items.append(TextPart(choice.message.content)) + if choice.message.tool_calls is not None: + for c in choice.message.tool_calls: + items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)) + return ModelResponse( + items, + usage=_map_usage(response), + model_name=response.model, + timestamp=timestamp, + vendor_id=response.id, + ) + + async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse: + """Process a streamed response, and prepare a streaming response to return.""" + peekable_response = _utils.PeekableAsyncStream(response) + first_chunk = await peekable_response.peek() + if isinstance(first_chunk, _utils.Unset): + raise UnexpectedModelBehavior( # pragma: no cover + 'Streamed response ended without content or tool calls' + ) + + return HuggingFaceStreamedResponse( + _model_name=self._model_name, + _response=peekable_response, + _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc), + ) + + def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]: + tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] + if model_request_parameters.output_tools: + tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] + return tools + + async def _map_messages( + self, messages: list[ModelMessage] + ) -> list[ChatCompletionInputMessage | ChatCompletionOutputMessage]: + """Just maps a `pydantic_ai.Message` to a 
`huggingface_hub.ChatCompletionInputMessage`.""" + hf_messages: list[ChatCompletionInputMessage | ChatCompletionOutputMessage] = [] + for message in messages: + if isinstance(message, ModelRequest): + async for item in self._map_user_message(message): + hf_messages.append(item) + elif isinstance(message, ModelResponse): + texts: list[str] = [] + tool_calls: list[ChatCompletionInputToolCall] = [] + for item in message.parts: + if isinstance(item, TextPart): + texts.append(item.content) + elif isinstance(item, ToolCallPart): + tool_calls.append(self._map_tool_call(item)) + else: + assert_never(item) + message_param = ChatCompletionInputMessage(role='assistant') # type: ignore + if texts: + # Note: model responses from this model should only have one text item, so the following + # shouldn't merge multiple texts into one unless you switch models between runs: + message_param['content'] = '\n\n'.join(texts) + if tool_calls: + message_param['tool_calls'] = tool_calls + hf_messages.append(message_param) + else: + assert_never(message) + if instructions := self._get_instructions(messages): + hf_messages.insert(0, ChatCompletionInputMessage(content=instructions, role='system')) # type: ignore + return hf_messages + + @staticmethod + def _map_tool_call(t: ToolCallPart) -> ChatCompletionInputToolCall: + return ChatCompletionInputToolCall.parse_obj_as_instance( # type: ignore + { + 'id': _guard_tool_call_id(t=t), + 'type': 'function', + 'function': { + 'name': t.tool_name, + 'arguments': t.args_as_json_str(), + }, + } + ) + + @staticmethod + def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool: + tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance( # type: ignore + { + 'type': 'function', + 'function': { + 'name': f.name, + 'description': f.description, + 'parameters': f.parameters_json_schema, + }, + } + ) + if f.strict: + tool_param['function']['strict'] = f.strict + return tool_param + + async def _map_user_message( + self, message: ModelRequest + ) -> AsyncIterable[ChatCompletionInputMessage | ChatCompletionOutputMessage]: + for part in message.parts: + if isinstance(part, SystemPromptPart): + yield ChatCompletionInputMessage.parse_obj_as_instance({'role': 'system', 'content': part.content}) # type: ignore + elif isinstance(part, UserPromptPart): + yield await self._map_user_prompt(part) + elif isinstance(part, ToolReturnPart): + yield ChatCompletionOutputMessage.parse_obj_as_instance( # type: ignore + { + 'role': 'tool', + 'tool_call_id': _guard_tool_call_id(t=part), + 'content': part.model_response_str(), + } + ) + elif isinstance(part, RetryPromptPart): + if part.tool_name is None: + yield ChatCompletionInputMessage.parse_obj_as_instance( # type: ignore + {'role': 'user', 'content': part.model_response()} + ) + else: + yield ChatCompletionInputMessage.parse_obj_as_instance( # type: ignore + { + 'role': 'tool', + 'tool_call_id': _guard_tool_call_id(t=part), + 'content': part.model_response(), + } + ) + else: + assert_never(part) + + @staticmethod + async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage: + content: str | list[ChatCompletionInputMessage] + if isinstance(part.content, str): + content = part.content + else: + content = [] + for item in part.content: + if isinstance(item, str): + content.append(ChatCompletionInputMessageChunk(type='text', text=item)) # type: ignore + elif isinstance(item, ImageUrl): + url = ChatCompletionInputURL(url=item.url) # type: ignore + 
content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url)) # type: ignore + elif isinstance(item, BinaryContent): + base64_encoded = base64.b64encode(item.data).decode('utf-8') + if item.is_image: + url = ChatCompletionInputURL(url=f'data:{item.media_type};base64,{base64_encoded}') # type: ignore + content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url)) # type: ignore + else: # pragma: no cover + raise RuntimeError(f'Unsupported binary content type: {item.media_type}') + elif isinstance(item, AudioUrl): + raise NotImplementedError('AudioUrl is not supported for Hugging Face') + elif isinstance(item, DocumentUrl): + raise NotImplementedError('DocumentUrl is not supported for Hugging Face') + elif isinstance(item, VideoUrl): # pragma: no cover + raise NotImplementedError('VideoUrl is not supported for Hugging Face') + else: + assert_never(item) + return ChatCompletionInputMessage(role='user', content=content) # type: ignore + + +@dataclass +class HuggingFaceStreamedResponse(StreamedResponse): + """Implementation of `StreamedResponse` for Hugging Face models.""" + + _model_name: str + _response: AsyncIterable[ChatCompletionStreamOutput] + _timestamp: datetime + + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + async for chunk in self._response: + self._usage += _map_usage(chunk) + + try: + choice = chunk.choices[0] + except IndexError: + continue + + # Handle the text part of the response + content = choice.delta.content + if content is not None: + yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content) + + for dtc in choice.delta.tool_calls or []: + maybe_event = self._parts_manager.handle_tool_call_delta( + vendor_part_id=dtc.index, + tool_name=dtc.function.name, + args=dtc.function.arguments, + tool_call_id=dtc.id, + ) + if maybe_event is not None: + yield maybe_event + + @property + def model_name(self) -> str: + """Get the model name of the response.""" + return self._model_name + + @property + def timestamp(self) -> datetime: + """Get the timestamp of the response.""" + return self._timestamp + + +def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.Usage: + response_usage = response.usage + if response_usage is None: + return usage.Usage() + + return usage.Usage( + request_tokens=response_usage.prompt_tokens, + response_tokens=response_usage.completion_tokens, + total_tokens=response_usage.total_tokens, + details=None, + ) diff --git a/pydantic_ai_slim/pydantic_ai/providers/__init__.py b/pydantic_ai_slim/pydantic_ai/providers/__init__.py index 379bbbc5da..dedeb6d8f9 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/providers/__init__.py @@ -89,5 +89,9 @@ def infer_provider(provider: str) -> Provider[Any]: from .cohere import CohereProvider return CohereProvider() + elif provider == 'huggingface': + from .huggingface import HuggingFaceProvider + + return HuggingFaceProvider() else: # pragma: no cover raise ValueError(f'Unknown provider: {provider}') diff --git a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py new file mode 100644 index 0000000000..182a1bc838 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py @@ -0,0 +1,72 @@ +from __future__ import annotations as _annotations + +import os + +from mistralai import httpx + +try: + from huggingface_hub import AsyncInferenceClient +except ImportError as _import_error: 
# pragma: no cover
+    raise ImportError(
+        'Please install the `huggingface_hub` package to use the HuggingFace provider, '
+        "you can use the `huggingface` optional group — `pip install 'pydantic-ai-slim[huggingface]'`"
+    ) from _import_error
+
+from . import Provider
+
+
+class HuggingFaceProvider(Provider[AsyncInferenceClient]):
+    """Provider for HuggingFace API."""
+
+    @property
+    def name(self) -> str:
+        return 'huggingface'
+
+    @property
+    def base_url(self) -> str:
+        return self.client.model  # type: ignore
+
+    @property
+    def client(self) -> AsyncInferenceClient:
+        return self._client
+
+    def __init__(
+        self,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        hf_client: AsyncInferenceClient | None = None,
+        http_client: httpx.AsyncClient | None = None,
+        provider: str | None = None,
+    ) -> None:
+        """Create a new Hugging Face provider.
+
+        Args:
+            base_url: The base url for the Hugging Face requests. If not provided, it will default to the HF Inference API base url.
+            api_key: The API key to use for authentication, if not provided, the `HF_TOKEN` environment variable
+                will be used if available.
+            hf_client: An existing
+                [`AsyncInferenceClient`](https://huggingface.co/docs/huggingface_hub/v0.29.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient)
+                client to use. If not provided, a new instance will be created.
+            http_client: (currently ignored) An existing `httpx.AsyncClient` to use for making HTTP requests.
+            provider: Name of the provider to use for inference. Available providers can be found in the
+                [HF Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners).
+                Defaults to "auto", which selects the first provider available for the model, sorted by the
+                user's order in https://hf.co/settings/inference-providers.
+                If `base_url` is passed, then `provider` is not used.
+        """
+        api_key = api_key or os.environ.get('HF_TOKEN')
+
+        if api_key is None:
+            raise ValueError(
+                'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)`'
+                'to use the HuggingFace provider.'
+ ) + + if http_client is not None: + raise ValueError('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead') + + if base_url is not None and provider is not None: + raise ValueError('Cannot provide both `base_url` and `provider`') + + if hf_client is None: + self._client = AsyncInferenceClient(api_key=api_key, provider=provider, base_url=base_url) # type: ignore + else: + self._client = hf_client diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 631cc196d0..6188867c2d 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -69,6 +69,7 @@ anthropic = ["anthropic>=0.52.0"] groq = ["groq>=0.15.0"] mistral = ["mistralai>=1.2.5"] bedrock = ["boto3>=1.35.74"] +huggingface = ["huggingface-hub>=0.32.0", "aiohttp"] # Tools duckduckgo = ["duckduckgo-search>=7.0.0"] tavily = ["tavily-python>=0.5.0"] @@ -81,6 +82,7 @@ evals = ["pydantic-evals=={{ version }}"] # A2A a2a = ["fasta2a=={{ version }}"] + [dependency-groups] dev = [ "anyio>=4.5.0", diff --git a/pyproject.toml b/pyproject.toml index 04b206432f..74404b1f75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ requires-python = ">=3.9" [tool.hatch.metadata.hooks.uv-dynamic-versioning] dependencies = [ - "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,cli,mcp,evals,a2a]=={{ version }}", + "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals,a2a]=={{ version }}", ] [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies] diff --git a/uv.lock b/uv.lock index 1da96b32ed..3c503af2f4 100644 --- a/uv.lock +++ b/uv.lock @@ -1312,6 +1312,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259, upload-time = "2022-09-25T15:39:59.68Z" }, ] +[[package]] +name = "hf-xet" +version = "1.1.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/75/dc/dc091aeeb671e71cbec30e84963f9c0202c17337b24b0a800e7d205543e8/hf_xet-1.1.3.tar.gz", hash = "sha256:a5f09b1dd24e6ff6bcedb4b0ddab2d81824098bb002cf8b4ffa780545fa348c3", size = 488127, upload-time = "2025-06-04T00:47:27.456Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/1f/bc01a4c0894973adebbcd4aa338a06815c76333ebb3921d94dcbd40dae6a/hf_xet-1.1.3-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c3b508b5f583a75641aebf732853deb058953370ce8184f5dabc49f803b0819b", size = 2256929, upload-time = "2025-06-04T00:47:21.206Z" }, + { url = "https://files.pythonhosted.org/packages/78/07/6ef50851b5c6b45b77a6e018fa299c69a2db3b8bbd0d5af594c0238b1ceb/hf_xet-1.1.3-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:b788a61977fbe6b5186e66239e2a329a3f0b7e7ff50dad38984c0c74f44aeca1", size = 2153719, upload-time = "2025-06-04T00:47:19.302Z" }, + { url = "https://files.pythonhosted.org/packages/52/48/e929e6e3db6e4758c2adf0f2ca2c59287f1b76229d8bdc1a4c9cfc05212e/hf_xet-1.1.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd2da210856444a34aad8ada2fc12f70dabed7cc20f37e90754d1d9b43bc0534", size = 4820519, upload-time = "2025-06-04T00:47:17.244Z" }, + { url = "https://files.pythonhosted.org/packages/28/2e/03f89c5014a5aafaa9b150655f811798a317036646623bdaace25f485ae8/hf_xet-1.1.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = 
"sha256:8203f52827e3df65981984936654a5b390566336956f65765a8aa58c362bb841", size = 4964121, upload-time = "2025-06-04T00:47:15.17Z" }, + { url = "https://files.pythonhosted.org/packages/47/8b/5cd399a92b47d98086f55fc72d69bc9ea5e5c6f27a9ed3e0cdd6be4e58a3/hf_xet-1.1.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:30c575a5306f8e6fda37edb866762140a435037365eba7a17ce7bd0bc0216a8b", size = 5283017, upload-time = "2025-06-04T00:47:23.239Z" }, + { url = "https://files.pythonhosted.org/packages/53/e3/2fcec58d2fcfd25ff07feb876f466cfa11f8dcf9d3b742c07fe9dd51ee0a/hf_xet-1.1.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:7c1a6aa6abed1f696f8099aa9796ca04c9ee778a58728a115607de9cc4638ff1", size = 4970349, upload-time = "2025-06-04T00:47:25.383Z" }, + { url = "https://files.pythonhosted.org/packages/53/bf/10ca917e335861101017ff46044c90e517b574fbb37219347b83be1952f6/hf_xet-1.1.3-cp37-abi3-win_amd64.whl", hash = "sha256:b578ae5ac9c056296bb0df9d018e597c8dc6390c5266f35b5c44696003cde9f3", size = 2310934, upload-time = "2025-06-04T00:47:29.632Z" }, +] + [[package]] name = "httpcore" version = "1.0.7" @@ -1351,20 +1366,21 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "0.29.1" +version = "0.32.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, { name = "fsspec" }, + { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, { name = "packaging" }, { name = "pyyaml" }, { name = "requests" }, { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/22/37/797d6476f13e5ef6af5fc48a5d641d32b39c37e166ccf40c3714c5854a85/huggingface_hub-0.29.1.tar.gz", hash = "sha256:9524eae42077b8ff4fc459ceb7a514eca1c1232b775276b009709fe2a084f250", size = 389776, upload-time = "2025-02-20T09:24:59.839Z" } +sdist = { url = "https://files.pythonhosted.org/packages/60/c8/4f7d270285c46324fd66f62159eb16739aa5696f422dba57678a8c6b78e9/huggingface_hub-0.32.4.tar.gz", hash = "sha256:f61d45cd338736f59fb0e97550b74c24ee771bcc92c05ae0766b9116abe720be", size = 424494, upload-time = "2025-06-03T09:59:46.105Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ae/05/75b90de9093de0aadafc868bb2fa7c57651fd8f45384adf39bd77f63980d/huggingface_hub-0.29.1-py3-none-any.whl", hash = "sha256:352f69caf16566c7b6de84b54a822f6238e17ddd8ae3da4f8f2272aea5b198d5", size = 468049, upload-time = "2025-02-20T09:24:57.962Z" }, + { url = "https://files.pythonhosted.org/packages/67/8b/222140f3cfb6f17b0dd8c4b9a0b36bd4ebefe9fb0098ba35d6960abcda0f/huggingface_hub-0.32.4-py3-none-any.whl", hash = "sha256:37abf8826b38d971f60d3625229221c36e53fe58060286db9baf619cfbf39767", size = 512101, upload-time = "2025-06-03T09:59:44.099Z" }, ] [[package]] @@ -2896,7 +2912,7 @@ wheels = [ name = "pydantic-ai" source = { editable = "." 
} dependencies = [ - { name = "pydantic-ai-slim", extra = ["a2a", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "mcp", "mistral", "openai", "vertexai"] }, + { name = "pydantic-ai-slim", extra = ["a2a", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"] }, ] [package.optional-dependencies] @@ -2930,7 +2946,7 @@ lint = [ requires-dist = [ { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "pydantic-ai-examples", marker = "extra == 'examples'", editable = "examples" }, - { name = "pydantic-ai-slim", extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, + { name = "pydantic-ai-slim", extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, ] provides-extras = ["examples", "logfire"] @@ -3029,6 +3045,10 @@ google = [ groq = [ { name = "groq" }, ] +huggingface = [ + { name = "aiohttp" }, + { name = "huggingface-hub" }, +] logfire = [ { name = "logfire" }, ] @@ -3071,6 +3091,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aiohttp", marker = "extra == 'huggingface'" }, { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.52.0" }, { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.35.74" }, @@ -3084,6 +3105,7 @@ requires-dist = [ { name = "griffe", specifier = ">=1.3.2" }, { name = "groq", marker = "extra == 'groq'", specifier = ">=0.15.0" }, { name = "httpx", specifier = ">=0.27" }, + { name = "huggingface-hub", marker = "extra == 'huggingface'", specifier = ">=0.32.0" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.9.0" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" }, @@ -3098,7 +3120,7 @@ requires-dist = [ { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] -provides-extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] +provides-extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] [package.metadata.requires-dev] dev = [ From 537a657931bd668ee07363a401338f2a7408eb6b Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 4 Jun 2025 17:22:45 +0100 Subject: [PATCH 02/27] update dependencies --- pydantic_ai_slim/pyproject.toml | 2 +- uv.lock | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 6188867c2d..bdafe2df9f 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -69,7 +69,7 @@ anthropic = ["anthropic>=0.52.0"] groq = ["groq>=0.15.0"] mistral = ["mistralai>=1.2.5"] bedrock = ["boto3>=1.35.74"] -huggingface = ["huggingface-hub>=0.32.0", "aiohttp"] +huggingface = ["huggingface-hub[inference]>=0.32.0"] # Tools duckduckgo = ["duckduckgo-search>=7.0.0"] tavily = ["tavily-python>=0.5.0"] diff --git a/uv.lock b/uv.lock index 
3c503af2f4..20af5addd7 100644 --- a/uv.lock +++ b/uv.lock @@ -1383,6 +1383,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/67/8b/222140f3cfb6f17b0dd8c4b9a0b36bd4ebefe9fb0098ba35d6960abcda0f/huggingface_hub-0.32.4-py3-none-any.whl", hash = "sha256:37abf8826b38d971f60d3625229221c36e53fe58060286db9baf619cfbf39767", size = 512101, upload-time = "2025-06-03T09:59:44.099Z" }, ] +[package.optional-dependencies] +inference = [ + { name = "aiohttp" }, +] + [[package]] name = "idna" version = "3.10" @@ -3046,8 +3051,7 @@ groq = [ { name = "groq" }, ] huggingface = [ - { name = "aiohttp" }, - { name = "huggingface-hub" }, + { name = "huggingface-hub", extra = ["inference"] }, ] logfire = [ { name = "logfire" }, @@ -3091,7 +3095,6 @@ dev = [ [package.metadata] requires-dist = [ - { name = "aiohttp", marker = "extra == 'huggingface'" }, { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.52.0" }, { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.35.74" }, @@ -3105,7 +3108,7 @@ requires-dist = [ { name = "griffe", specifier = ">=1.3.2" }, { name = "groq", marker = "extra == 'groq'", specifier = ">=0.15.0" }, { name = "httpx", specifier = ">=0.27" }, - { name = "huggingface-hub", marker = "extra == 'huggingface'", specifier = ">=0.32.0" }, + { name = "huggingface-hub", extras = ["inference"], marker = "extra == 'huggingface'", specifier = ">=0.32.0" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.9.0" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" }, From af602a509040e2496e81924b3c038c8c254c16e9 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 4 Jun 2025 17:23:38 +0100 Subject: [PATCH 03/27] nit --- pydantic_ai_slim/pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index bdafe2df9f..e0fd6c6f3e 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -82,7 +82,6 @@ evals = ["pydantic-evals=={{ version }}"] # A2A a2a = ["fasta2a=={{ version }}"] - [dependency-groups] dev = [ "anyio>=4.5.0", From 1f3f7a21d0b315498564542d7f5867a5d3a4283b Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 4 Jun 2025 17:24:34 +0100 Subject: [PATCH 04/27] update docstring --- pydantic_ai_slim/pydantic_ai/providers/huggingface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py index 182a1bc838..3d301340ff 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py @@ -16,7 +16,7 @@ class HuggingFaceProvider(Provider[AsyncInferenceClient]): - """Provider for HuggingFace API.""" + """Provider for Hugging Face.""" @property def name(self) -> str: From bea050c69abbb2abfb51ddc4338b088f137ec346 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 13:16:57 +0100 Subject: [PATCH 05/27] add tests --- .../pydantic_ai/models/huggingface.py | 13 +- .../pydantic_ai/providers/huggingface.py | 10 +- tests/conftest.py | 14 + .../test_hf_model_instructions.yaml | 125 ++++ .../test_request_simple_success_with_vcr.yaml | 126 ++++ tests/models/test_huggingface.py | 692 ++++++++++++++++++ tests/providers/test_huggingface.py | 61 ++ 7 files 
changed, 1034 insertions(+), 7 deletions(-) create mode 100644 tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml create mode 100644 tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml create mode 100644 tests/models/test_huggingface.py create mode 100644 tests/providers/test_huggingface.py diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index c34d741a37..1dc1db73ca 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -49,6 +49,7 @@ ChatCompletionStreamOutput, InferenceTimeoutError, ) + from huggingface_hub.errors import HfHubHTTPError except ImportError as _import_error: raise ImportError( @@ -198,13 +199,19 @@ async def _completions_create( top_logprobs=model_settings.get('top_logprobs', None), extra_body=model_settings.get('extra_body'), # type: ignore ) - except (InferenceTimeoutError, aiohttp.ClientResponseError) as e: + except (InferenceTimeoutError, aiohttp.ClientResponseError, HfHubHTTPError) as e: if isinstance(e, aiohttp.ClientResponseError): raise ModelHTTPError( status_code=e.status, model_name=self.model_name, body=e.response_error_payload, # type: ignore ) from e + elif isinstance(e, HfHubHTTPError): + raise ModelHTTPError( + status_code=e.response.status_code, + model_name=self.model_name, + body=e.response.content, + ) from e raise # pragma: lax no cover def _process_response(self, response: ChatCompletionOutput) -> ModelResponse: @@ -401,8 +408,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: for dtc in choice.delta.tool_calls or []: maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=dtc.index, - tool_name=dtc.function.name, - args=dtc.function.arguments, + tool_name=dtc.function and dtc.function.name, # type: ignore + args=dtc.function and dtc.function.arguments, tool_call_id=dtc.id, ) if maybe_event is not None: diff --git a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py index 3d301340ff..e18a60d16c 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py @@ -2,7 +2,9 @@ import os -from mistralai import httpx +from httpx import AsyncClient + +from pydantic_ai.exceptions import UserError try: from huggingface_hub import AsyncInferenceClient @@ -35,13 +37,13 @@ def __init__( base_url: str | None = None, api_key: str | None = None, hf_client: AsyncInferenceClient | None = None, - http_client: httpx.AsyncClient | None = None, + http_client: AsyncClient | None = None, provider: str | None = None, ) -> None: """Create a new Hugging Face provider. Args: - base_url: The base url for the Hugging Face requests. If not provided, it will default to the HF Inference API base url. + base_url: The base url for the Hugging Face requests. api_key: The API key to use for authentication, if not provided, the `HF_TOKEN` environment variable will be used if available. hf_client: An existing @@ -55,7 +57,7 @@ def __init__( api_key = api_key or os.environ.get('HF_TOKEN') if api_key is None: - raise ValueError( + raise UserError( 'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)`' 'to use the HuggingFace provider.' 
) diff --git a/tests/conftest.py b/tests/conftest.py index 65c104718f..db8459fc89 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -281,6 +281,11 @@ def openrouter_api_key() -> str: return os.getenv('OPENROUTER_API_KEY', 'mock-api-key') +@pytest.fixture(scope='session') +def huggingface_api_key() -> str: + return os.getenv('HF_TOKEN', 'mock-api-key') or os.getenv('HUGGINGFACE_API_KEY', 'mock-api-key') + + @pytest.fixture(scope='session') def bedrock_provider(): try: @@ -309,6 +314,7 @@ def model( groq_api_key: str, co_api_key: str, gemini_api_key: str, + huggingface_api_key: str, bedrock_provider: BedrockProvider, ) -> Model: # pragma: lax no cover try: @@ -346,6 +352,14 @@ def model( from pydantic_ai.models.bedrock import BedrockConverseModel return BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) + elif request.param == 'huggingface': + from pydantic_ai.models.huggingface import HuggingFaceModel + from pydantic_ai.providers.huggingface import HuggingFaceProvider + + return HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', + provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key), + ) else: raise ValueError(f'Unknown model: {request.param}') except ImportError: diff --git a/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml new file mode 100644 index 0000000000..f621f4c4f8 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml @@ -0,0 +1,125 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '800' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"320-IoLwHc4XKGzRoHW0ok1gY7tY/NI" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: error + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hf-inference: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: live + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - 
X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '560' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: Paris + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1749227878 + id: chatcmpl-54246cfb4fa046e88a984020c4efab20 + model: Qwen/Qwen2.5-72B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 2 + completion_tokens_details: null + prompt_tokens: 26 + prompt_tokens_details: null + total_tokens: 28 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml b/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml new file mode 100644 index 0000000000..c9a3b50f2a --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml @@ -0,0 +1,126 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '800' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"320-IoLwHc4XKGzRoHW0ok1gY7tY/NI" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: error + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hf-inference: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: live + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '700' + content-type: + - application/json + cross-origin-opener-policy: 
+ - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: Hello! It's great to meet you. How can I assist you today? Whether you have questions, need information, + or just want to chat, I'm here to help! + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1749226637 + id: chatcmpl-f5783ce357b4415b8d59dbbf5b3cf9bf + model: Qwen/Qwen2.5-72B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 37 + completion_tokens_details: null + prompt_tokens: 30 + prompt_tokens_details: null + total_tokens: 67 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py new file mode 100644 index 0000000000..378ab675a0 --- /dev/null +++ b/tests/models/test_huggingface.py @@ -0,0 +1,692 @@ +from __future__ import annotations as _annotations + +import json +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import datetime, timezone +from functools import cached_property +from typing import Any, Literal, Union, cast +from unittest.mock import Mock + +import pytest +from huggingface_hub import ( + ChatCompletionStreamOutputChoice, + ChatCompletionStreamOutputDelta, + ChatCompletionStreamOutputDeltaToolCall, + ChatCompletionStreamOutputFunction, + ChatCompletionStreamOutputUsage, +) +from inline_snapshot import snapshot +from typing_extensions import TypedDict + +from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior +from pydantic_ai.exceptions import ModelHTTPError +from pydantic_ai.messages import ( + BinaryContent, + ImageUrl, + ModelRequest, + ModelResponse, + RetryPromptPart, + SystemPromptPart, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, +) +from pydantic_ai.result import Usage +from pydantic_ai.tools import RunContext + +from ..conftest import IsDatetime, IsNow, raise_if_exception, try_import +from .mock_async_stream import MockAsyncStream + +with try_import() as imports_successful: + from huggingface_hub import ( + AsyncInferenceClient, + ChatCompletionInputMessage, + ChatCompletionOutput, + ChatCompletionOutputComplete, + ChatCompletionOutputFunctionDefinition, + ChatCompletionOutputMessage, + ChatCompletionOutputToolCall, + ChatCompletionOutputUsage, + ChatCompletionStreamOutput, + ) + from huggingface_hub.errors import HfHubHTTPError + + from pydantic_ai.models.huggingface import HuggingFaceModel + from pydantic_ai.providers.huggingface import HuggingFaceProvider + + MockChatCompletion = Union[ChatCompletionOutput, Exception] + MockStreamEvent = Union[ChatCompletionStreamOutput, Exception] + +pytestmark = [ + pytest.mark.skipif(not imports_successful(), reason='huggingface_hub not installed'), + pytest.mark.anyio, +] + + +@dataclass +class MockHuggingFace: + completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None + stream: Sequence[MockStreamEvent] | Sequence[Sequence[MockStreamEvent]] | None = None + index: int = 0 + + @cached_property + def chat(self) -> Any: + completions = type('Completions', (), {'create': self.chat_completions_create}) + return type('Chat', (), {'completions': completions}) + + @classmethod + def create_mock(cls, completions: MockChatCompletion 
| Sequence[MockChatCompletion]) -> AsyncInferenceClient: + return cast(AsyncInferenceClient, cls(completions=completions)) + + @classmethod + def create_stream_mock( + cls, stream: Sequence[MockStreamEvent] | Sequence[Sequence[MockStreamEvent]] + ) -> AsyncInferenceClient: + return cast(AsyncInferenceClient, cls(stream=stream)) + + async def chat_completions_create( + self, *_args: Any, stream: bool = False, **_kwargs: Any + ) -> ChatCompletionOutput | MockAsyncStream[MockStreamEvent]: + if stream or self.stream: + assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided' + if isinstance(self.stream[0], Sequence): + response = MockAsyncStream(iter(cast(list[MockStreamEvent], self.stream[self.index]))) + else: + response = MockAsyncStream(iter(cast(list[MockStreamEvent], self.stream))) + else: + assert self.completions is not None, 'you can only use `stream=False` if `completions` are provided' + if isinstance(self.completions, Sequence): + raise_if_exception(self.completions[self.index]) + response = cast(ChatCompletionOutput, self.completions[self.index]) + else: + raise_if_exception(self.completions) + response = cast(ChatCompletionOutput, self.completions) + self.index += 1 + return response + + +def completion_message( + message: ChatCompletionInputMessage | ChatCompletionOutputMessage, *, usage: ChatCompletionOutputUsage | None = None +) -> ChatCompletionOutput: + choices = [ChatCompletionOutputComplete(finish_reason='stop', index=0, message=message)] # type:ignore + return ChatCompletionOutput.parse_obj_as_instance( # type: ignore + { + 'id': '123', + 'choices': choices, + 'created': 1704067200, # 2024-01-01 + 'model': 'hf-model', + 'object': 'chat.completion', + 'usage': usage, + } + ) + + +async def test_simple_completion(allow_model_requests: None): + c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore + mock_client = MockHuggingFace.create_mock(c) + model = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x') + ) + agent = Agent(model) + + result = await agent.run('hello') + assert result.output == 'world' + messages = result.all_messages() + request = messages[0] + response = messages[1] + assert request.parts[0].content == 'hello' # type: ignore + assert response == ModelResponse( + parts=[TextPart(content='world')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ) + + +async def test_request_simple_usage(allow_model_requests: None): + c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore + mock_client = MockHuggingFace.create_mock(c) + model = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x') + ) + agent = Agent(model) + + result = await agent.run('Hello') + assert result.output == 'world' + assert result.usage() == snapshot(Usage(requests=1)) + + +async def test_request_structured_response(allow_model_requests: None): + tool_call = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore + { + 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore + { + 'name': 'final_result', + 'arguments': '{"response": [1, 2, 123]}', + } + ), + 'id': '123', + 'type': 'function', + } + ) + message = ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': 
None, + 'role': 'assistant', + 'tool_calls': [tool_call], + } + ) + c = completion_message(message) + + mock_client = MockHuggingFace.create_mock(c) + model = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x') + ) + agent = Agent(model, output_type=list[int]) + + result = await agent.run('Hello') + assert result.output == [1, 2, 123] + messages = result.all_messages() + assert messages[0].parts[0].content == 'Hello' # type: ignore + assert messages[1] == ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"response": [1, 2, 123]}', + tool_call_id='123', + ) + ], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), + vendor_id='123', + ) + + +async def test_stream_completion(allow_model_requests: None): + stream = [text_chunk('hello '), text_chunk('world', finish_reason='stop')] + mock_client = MockHuggingFace.create_stream_mock(stream) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(model) + + async with agent.run_stream('') as result: + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) + + +async def test_request_tool_call(allow_model_requests: None): + tool_call_1 = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore + { + 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore + { + 'name': 'get_location', + 'arguments': '{"loc_name": "San Fransisco"}', + } + ), + 'id': '1', + 'type': 'function', + } + ) + usage_1 = ChatCompletionOutputUsage.parse_obj_as_instance( # type:ignore + { + 'prompt_tokens': 1, + 'completion_tokens': 1, + 'total_tokens': 2, + } + ) + tool_call_2 = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore + { + 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore + { + 'name': 'get_location', + 'arguments': '{"loc_name": "London"}', + } + ), + 'id': '2', + 'type': 'function', + } + ) + usage_2 = ChatCompletionOutputUsage.parse_obj_as_instance( # type:ignore + { + 'prompt_tokens': 2, + 'completion_tokens': 1, + 'total_tokens': 3, + } + ) + responses = [ + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': None, + 'role': 'assistant', + 'tool_calls': [tool_call_1], + } + ), + usage=usage_1, + ), + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': None, + 'role': 'assistant', + 'tool_calls': [tool_call_2], + } + ), + usage=usage_2, + ), + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': 'final response', + 'role': 'assistant', + } + ), + ), + ] + mock_client = MockHuggingFace.create_mock(responses) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(model, system_prompt='this is the system prompt') + + @agent.tool_plain + async def get_location(loc_name: str) -> str: + if loc_name == 'London': + return json.dumps({'lat': 51, 'lng': 0}) + else: + raise ModelRetry('Wrong location, please try again') + + result = await agent.run('Hello') + assert result.output == 'final response' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + SystemPromptPart(content='this is the system prompt', timestamp=IsNow(tz=timezone.utc)), + 
UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_location', + args='{"loc_name": "San Fransisco"}', + tool_call_id='1', + ) + ], + usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=2), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Wrong location, please try again', + tool_name='get_location', + tool_call_id='1', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_location', + args='{"loc_name": "London"}', + tool_call_id='2', + ) + ], + usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_location', + content='{"lat": 51, "lng": 0}', + tool_call_id='2', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='final response')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ] + ) + + +FinishReason = Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] + + +def chunk( + delta: list[ChatCompletionStreamOutputDelta], finish_reason: FinishReason | None = None +) -> ChatCompletionStreamOutput: + return ChatCompletionStreamOutput.parse_obj_as_instance( # type: ignore + { + 'id': 'x', + 'choices': [ + ChatCompletionStreamOutputChoice(index=index, delta=delta, finish_reason=finish_reason) # type: ignore + for index, delta in enumerate(delta) + ], + 'created': 1704067200, # 2024-01-01 + 'model': 'hf-model', + 'object': 'chat.completion.chunk', + 'usage': ChatCompletionStreamOutputUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3), # type: ignore + } + ) + + +def text_chunk(text: str, finish_reason: FinishReason | None = None) -> ChatCompletionStreamOutput: + return chunk([ChatCompletionStreamOutputDelta(content=text, role='assistant')], finish_reason=finish_reason) # type: ignore + + +async def test_stream_text(allow_model_requests: None): + stream = [text_chunk('hello '), text_chunk('world'), chunk([])] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert result.is_complete + assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) + + +async def test_stream_text_finish_reason(allow_model_requests: None): + stream = [ + text_chunk('hello '), + text_chunk('world'), + text_chunk('.', finish_reason='stop'), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot( + ['hello ', 'hello world', 'hello world.'] + ) + assert result.is_complete + + +def struc_chunk( + tool_name: str | None, tool_arguments: str | None, finish_reason: FinishReason | None = None +) -> 
ChatCompletionStreamOutput: + return chunk( + [ + ChatCompletionStreamOutputDelta.parse_obj_as_instance( # type: ignore + { + 'role': 'assistant', + 'tool_calls': [ + ChatCompletionStreamOutputDeltaToolCall.parse_obj_as_instance( # type: ignore + { + 'index': 0, + 'function': ChatCompletionStreamOutputFunction.parse_obj_as_instance( # type: ignore + { + 'name': tool_name, + 'arguments': tool_arguments, + } + ), + } + ) + ], + } + ), + ], + finish_reason=finish_reason, + ) + + +class MyTypedDict(TypedDict, total=False): + first: str + second: str + + +async def test_stream_structured(allow_model_requests: None): + stream = [ + chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore + chunk([ChatCompletionStreamOutputDelta(role='assistant', tool_calls=[])]), # type: ignore + chunk( + [ + ChatCompletionStreamOutputDelta( + role='assistant', # type: ignore + tool_calls=[ # type: ignore + ChatCompletionStreamOutputDeltaToolCall(id='0', type='function', index=0, function=None) # type: ignore + ], + ) + ] + ), + chunk( + [ + ChatCompletionStreamOutputDelta( + role='assistant', # type: ignore + tool_calls=[ # type: ignore + ChatCompletionStreamOutputDeltaToolCall(id='0', type='function', index=0, function=None) # type: ignore + ], + ) + ] + ), + struc_chunk('final_result', None), + chunk( + [ + ChatCompletionStreamOutputDelta( + role='assistant', # type: ignore + tool_calls=[ # type: ignore + ChatCompletionStreamOutputDeltaToolCall(id='0', type='function', index=0, function=None) # type: ignore + ], + ) + ] + ), + struc_chunk(None, '{"first": "One'), + struc_chunk(None, '", "second": "Two"'), + struc_chunk(None, '}'), + chunk([]), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m, output_type=MyTypedDict) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( + [ + {}, + {'first': 'One'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + ] + ) + assert result.is_complete + assert result.usage() == snapshot(Usage(requests=1, request_tokens=20, response_tokens=10, total_tokens=30)) + # double check usage matches stream count + assert result.usage().response_tokens == len(stream) + + +async def test_stream_structured_finish_reason(allow_model_requests: None): + stream = [ + struc_chunk('final_result', None), + struc_chunk(None, '{"first": "One'), + struc_chunk(None, '", "second": "Two"'), + struc_chunk(None, '}'), + struc_chunk(None, None, finish_reason='stop'), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m, output_type=MyTypedDict) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( + [ + {'first': 'One'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + ] + ) + assert result.is_complete + + +async def test_no_content(allow_model_requests: None): + stream = [ + chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore + chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = 
HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m, output_type=MyTypedDict) + + with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'): + async with agent.run_stream(''): + pass + + +async def test_no_delta(allow_model_requests: None): + stream = [ + chunk([]), + text_chunk('hello '), + text_chunk('world'), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert result.is_complete + assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) + + +async def test_image_url_input(allow_model_requests: None): + c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore + mock_client = MockHuggingFace.create_mock(c) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m) + + result = await agent.run( + [ + 'hello', + ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg'), + ] + ) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content=[ + 'hello', + ImageUrl( + url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg' + ), + ], + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='world')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ] + ) + + +async def test_image_as_binary_content_input(allow_model_requests: None): + c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type: ignore + mock_client = MockHuggingFace.create_mock(c) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m) + + base64_content = ( + b'/9j/4AAQSkZJRgABAQEAYABgAAD/4QBYRXhpZgAATU0AKgAAAAgAA1IBAAEAAAABAAAAPgIBAAEAAAABAAAARgMBAAEAAAABAAAA' + b'WgAAAAAAAAAE' + ) + + result = await agent.run(['hello', BinaryContent(data=base64_content, media_type='image/jpeg')]) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content=['hello', BinaryContent(data=base64_content, media_type='image/jpeg')], + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='world')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ] + ) + + +def test_model_status_error(allow_model_requests: None) -> None: + error = HfHubHTTPError(message='test_error', response=Mock(status_code=500, content={'error': 'test error'})) + mock_client = MockHuggingFace.create_mock(error) + m = HuggingFaceModel('not_a_model', provider=HuggingFaceProvider(hf_client=mock_client)) + agent = Agent(m) + with pytest.raises(ModelHTTPError) as exc_info: + agent.run_sync('hello') + assert str(exc_info.value) == snapshot("status_code: 500, model_name: not_a_model, body: {'error': 'test error'}") + + +@pytest.mark.vcr() +async def test_request_simple_success_with_vcr(allow_model_requests: None, huggingface_api_key: str): + m = HuggingFaceModel( + 
'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key) + ) + agent = Agent(m) + result = await agent.run('hello') + assert result.output == snapshot( + "Hello! It's great to meet you. How can I assist you today? Whether you have questions, need information, or just want to chat, I'm here to help!" + ) + + +@pytest.mark.vcr() +async def test_hf_model_instructions(allow_model_requests: None, huggingface_api_key: str): + m = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key) + ) + + def simple_instructions(ctx: RunContext): + return 'You are a helpful assistant.' + + agent = Agent(m, instructions=simple_instructions) + + result = await agent.run('What is the capital of France?') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='What is the capital of France?', timestamp=IsDatetime())], + instructions='You are a helpful assistant.', + ), + ModelResponse( + parts=[TextPart(content='Paris')], + usage=Usage(requests=1, request_tokens=26, response_tokens=2, total_tokens=28), + model_name='Qwen/Qwen2.5-72B-Instruct-fast', + timestamp=IsDatetime(), + vendor_id='chatcmpl-54246cfb4fa046e88a984020c4efab20', + ), + ] + ) diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py new file mode 100644 index 0000000000..5b52bfe233 --- /dev/null +++ b/tests/providers/test_huggingface.py @@ -0,0 +1,61 @@ +from __future__ import annotations as _annotations + +import re + +import httpx +import pytest + +from pydantic_ai.exceptions import UserError + +from ..conftest import TestEnv, try_import + +with try_import() as imports_successful: + from huggingface_hub import AsyncInferenceClient + + from pydantic_ai.providers.huggingface import HuggingFaceProvider + + +pytestmark = pytest.mark.skipif(not imports_successful(), reason='huggingface_hub not installed') + + +def test_huggingface_provider(): + hf_client = AsyncInferenceClient(api_key='api-key') + provider = HuggingFaceProvider(api_key='api-key', hf_client=hf_client) + assert provider.name == 'huggingface' + assert isinstance(provider.client, AsyncInferenceClient) + assert provider.client.token == 'api-key' + + +def test_huggingface_provider_need_api_key(env: TestEnv) -> None: + env.remove('HF_TOKEN') + with pytest.raises( + UserError, + match=re.escape( + 'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)`' + 'to use the HuggingFace provider.' 
+ ), + ): + HuggingFaceProvider() + + +def test_huggingface_provider_pass_http_client() -> None: + http_client = httpx.AsyncClient() + with pytest.raises( + ValueError, + match=re.escape('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead'), + ): + HuggingFaceProvider(http_client=http_client, api_key='api-key') + + +def test_huggingface_provider_pass_hf_client() -> None: + hf_client = AsyncInferenceClient(api_key='api-key') + provider = HuggingFaceProvider(hf_client=hf_client) + assert provider.client == hf_client + + +def test_hf_provider_with_base_url() -> None: + # Test with environment variable for base_url + provider = HuggingFaceProvider( + hf_client=AsyncInferenceClient(api_key='test-api-key', base_url='https://router.huggingface.co/nebius/v1'), + ) + assert provider.base_url == 'https://router.huggingface.co/nebius/v1' From 40aef2e6f97f1581d5bb4c419852c6a5af4e698f Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 13:49:44 +0100 Subject: [PATCH 06/27] add docs and known models for hf --- docs/models/huggingface.md | 84 +++++++++++++++++++ .../pydantic_ai/models/__init__.py | 14 +++- .../pydantic_ai/models/huggingface.py | 23 ++++- tests/models/test_model_names.py | 3 + 4 files changed, 121 insertions(+), 3 deletions(-) create mode 100644 docs/models/huggingface.md diff --git a/docs/models/huggingface.md b/docs/models/huggingface.md new file mode 100644 index 0000000000..8edeb8cc3b --- /dev/null +++ b/docs/models/huggingface.md @@ -0,0 +1,84 @@ +# Hugging Face + + +## Install + +To use `HuggingFace`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `huggingface` optional group: + +```bash +pip/uv-add "pydantic-ai-slim[huggingface]" +``` + +## Configuration + +To use `HuggingFaceModel` through their main API, go to [Inference Providers documentation](https://huggingface.co/docs/inference-providers/pricing) for all the details, and you can generate a Hugging Face Token here: https://huggingface.co/settings/tokens. + +## Environment variable + +Once you have a HuggingFace Token, you can set it as an environment variable: + +```bash +export HF_TOKEN='your-hf-token' +``` + +You can then use `HuggingFaceModel` by name: + +```python +from pydantic_ai import Agent + +agent = Agent('huggingface:Qwen/Qwen3-235B-A22B') +... +``` + +Or initialise the model directly with just the model name: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.huggingface import HuggingFaceModel + +model = HuggingFaceModel('Qwen/Qwen3-235B-A22B') +agent = Agent(model) +... +``` + +By default, the `HuggingFaceModel` uses the `HuggingFaceProvider` that will select automatically the first of the inference providers (Cerebras, Together AI, Cohere..etc) available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers. + +## Configure the provider + +If you want to pass parameters in code to the provider, you can programmatically instantiate the +[HuggingFaceProvider][pydantic_ai.providers.huggingface.HuggingFaceProvider] and pass it to the model: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.huggingface import HuggingFaceModel +from pydantic_ai.providers.huggingface import HuggingFaceProvider + +model = HuggingFaceModel('Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(api_key='your-api-key', provider="nebius")) +agent = Agent(model) +... 
+``` + +## Custom Hugging Face Client + +`HuggingFaceProvider` also accepts a custom `AsyncInferenceClient` client via the `hf_client` parameter, so you can customise the `headers`, `bill_to` (billing to an HF organization you're a member of), `base_url` etc. as defined in the [Hugging Face Hub python library docs](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client). + +```python +from huggingface_hub import AsyncInferenceClient + +from pydantic_ai import Agent +from pydantic_ai.models.huggingface import HuggingFaceModel +from pydantic_ai.providers.huggingface import HuggingFaceProvider + +client = AsyncInferenceClient( + bill_to="openai", + api_key='your-api-key', + provider="fireworks-ai", +) + +model = HuggingFaceModel( + 'Qwen/Qwen3-235B-A22B', + provider=HuggingFaceProvider(hf_client=client), +) +agent = Agent(model) +... +``` diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 92b919181d..e42be9dfce 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -204,6 +204,14 @@ 'groq:llama-3.2-3b-preview', 'groq:llama-3.2-11b-vision-preview', 'groq:llama-3.2-90b-vision-preview', + 'huggingface:Qwen/QwQ-32B', + 'huggingface:Qwen/Qwen2.5-72B-Instruct', + 'huggingface:Qwen/Qwen3-235B-A22B', + 'huggingface:Qwen/Qwen3-32B', + 'huggingface:deepseek-ai/DeepSeek-R1', + 'huggingface:meta-llama/Llama-3.3-70B-Instruct', + 'huggingface:meta-llama/Llama-4-Maverick-17B-128E-Instruct', + 'huggingface:meta-llama/Llama-4-Scout-17B-16E-Instruct', 'mistral:codestral-latest', 'mistral:mistral-large-latest', 'mistral:mistral-moderation-latest', @@ -485,7 +493,7 @@ def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]: ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition] -def infer_model(model: Model | KnownModelName | str) -> Model: +def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 """Infer the model from the name.""" if isinstance(model, Model): return model @@ -539,6 +547,10 @@ def infer_model(model: Model | KnownModelName | str) -> Model: from .bedrock import BedrockConverseModel return BedrockConverseModel(model_name, provider=provider) + elif provider == 'huggingface': + from .huggingface import HuggingFaceModel + + return HuggingFaceModel(model_name, provider=provider) else: raise UserError(f'Unknown model: {model}') # pragma: no cover diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 1dc1db73ca..5de42a2708 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -5,7 +5,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Literal, cast, overload +from typing import Literal, Union, cast, overload from typing_extensions import assert_never @@ -65,6 +65,25 @@ HFSystemPromptRole = Literal['system', 'user'] +LatestHuggingFaceModelNames = Literal[ + 'deepseek-ai/DeepSeek-R1', + 'meta-llama/Llama-3.3-70B-Instruct', + 'meta-llama/Llama-4-Maverick-17B-128E-Instruct', + 'meta-llama/Llama-4-Scout-17B-16E-Instruct', + 'Qwen/QwQ-32B', + 'Qwen/Qwen2.5-72B-Instruct', + 'Qwen/Qwen3-235B-A22B', + 'Qwen/Qwen3-32B', +] +"""Latest Hugging Face models.""" + + +HuggingFaceModelName = Union[str, LatestHuggingFaceModelNames] +"""Possible Hugging Face model names. 
+ +You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending). +""" + class HuggingFaceModelSettings(ModelSettings, total=False): """Settings used for a Hugging Face model request. @@ -136,7 +155,7 @@ async def request_stream( yield await self._process_streamed_response(response) @property - def model_name(self) -> str: + def model_name(self) -> HuggingFaceModelName: """The model name.""" return self._model_name diff --git a/tests/models/test_model_names.py b/tests/models/test_model_names.py index 63eb7c650d..da53d4dd98 100644 --- a/tests/models/test_model_names.py +++ b/tests/models/test_model_names.py @@ -14,6 +14,7 @@ from pydantic_ai.models.cohere import CohereModelName from pydantic_ai.models.gemini import GeminiModelName from pydantic_ai.models.groq import GroqModelName + from pydantic_ai.models.huggingface import HuggingFaceModelName from pydantic_ai.models.mistral import MistralModelName from pydantic_ai.models.openai import OpenAIModelName @@ -44,6 +45,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]: ] bedrock_names = [f'bedrock:{n}' for n in get_model_names(BedrockModelName)] deepseek_names = ['deepseek:deepseek-chat', 'deepseek:deepseek-reasoner'] + huggingface_names = [f'huggingface:{n}' for n in get_model_names(HuggingFaceModelName)] extra_names = ['test'] generated_names = sorted( @@ -55,6 +57,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]: + openai_names + bedrock_names + deepseek_names + + huggingface_names + extra_names ) From 7a4b9a4ce3f0f0a5b8816d3e38251db37b52e502 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 14:10:47 +0100 Subject: [PATCH 07/27] fix imports in test --- tests/models/test_huggingface.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 378ab675a0..1fd8b63ea4 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -9,13 +9,6 @@ from unittest.mock import Mock import pytest -from huggingface_hub import ( - ChatCompletionStreamOutputChoice, - ChatCompletionStreamOutputDelta, - ChatCompletionStreamOutputDeltaToolCall, - ChatCompletionStreamOutputFunction, - ChatCompletionStreamOutputUsage, -) from inline_snapshot import snapshot from typing_extensions import TypedDict @@ -50,6 +43,11 @@ ChatCompletionOutputToolCall, ChatCompletionOutputUsage, ChatCompletionStreamOutput, + ChatCompletionStreamOutputChoice, + ChatCompletionStreamOutputDelta, + ChatCompletionStreamOutputDeltaToolCall, + ChatCompletionStreamOutputFunction, + ChatCompletionStreamOutputUsage, ) from huggingface_hub.errors import HfHubHTTPError From a1530818c6328161f0d892a352c1fec894901e6e Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 14:17:54 +0100 Subject: [PATCH 08/27] fix tests --- tests/models/test_huggingface.py | 18 +++++++++--------- tests/providers/test_huggingface.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 1fd8b63ea4..731728b7b1 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -389,7 +389,7 @@ def text_chunk(text: str, finish_reason: FinishReason | None = None) -> ChatComp async def test_stream_text(allow_model_requests: None): stream = [text_chunk('hello '), text_chunk('world'), chunk([])] mock_client = MockHuggingFace.create_stream_mock(stream) - m = 
HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m) async with agent.run_stream('') as result: @@ -406,7 +406,7 @@ async def test_stream_text_finish_reason(allow_model_requests: None): text_chunk('.', finish_reason='stop'), ] mock_client = MockHuggingFace.create_stream_mock(stream) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m) async with agent.run_stream('') as result: @@ -491,7 +491,7 @@ async def test_stream_structured(allow_model_requests: None): chunk([]), ] mock_client = MockHuggingFace.create_stream_mock(stream) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m, output_type=MyTypedDict) async with agent.run_stream('') as result: @@ -520,7 +520,7 @@ async def test_stream_structured_finish_reason(allow_model_requests: None): struc_chunk(None, None, finish_reason='stop'), ] mock_client = MockHuggingFace.create_stream_mock(stream) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m, output_type=MyTypedDict) async with agent.run_stream('') as result: @@ -543,7 +543,7 @@ async def test_no_content(allow_model_requests: None): chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore ] mock_client = MockHuggingFace.create_stream_mock(stream) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m, output_type=MyTypedDict) with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'): @@ -558,7 +558,7 @@ async def test_no_delta(allow_model_requests: None): text_chunk('world'), ] mock_client = MockHuggingFace.create_stream_mock(stream) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m) async with agent.run_stream('') as result: @@ -571,7 +571,7 @@ async def test_no_delta(allow_model_requests: None): async def test_image_url_input(allow_model_requests: None): c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore mock_client = MockHuggingFace.create_mock(c) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m) result = await agent.run( @@ -609,7 +609,7 @@ async def test_image_url_input(allow_model_requests: None): async def test_image_as_binary_content_input(allow_model_requests: None): c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type: ignore mock_client = MockHuggingFace.create_mock(c) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m) base64_content = ( @@ -642,7 
+642,7 @@ async def test_image_as_binary_content_input(allow_model_requests: None): def test_model_status_error(allow_model_requests: None) -> None: error = HfHubHTTPError(message='test_error', response=Mock(status_code=500, content={'error': 'test error'})) mock_client = MockHuggingFace.create_mock(error) - m = HuggingFaceModel('not_a_model', provider=HuggingFaceProvider(hf_client=mock_client)) + m = HuggingFaceModel('not_a_model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(m) with pytest.raises(ModelHTTPError) as exc_info: agent.run_sync('hello') diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py index 5b52bfe233..4df8d8f193 100644 --- a/tests/providers/test_huggingface.py +++ b/tests/providers/test_huggingface.py @@ -56,6 +56,6 @@ def test_huggingface_provider_pass_hf_client() -> None: def test_hf_provider_with_base_url() -> None: # Test with environment variable for base_url provider = HuggingFaceProvider( - hf_client=AsyncInferenceClient(api_key='test-api-key', base_url='https://router.huggingface.co/nebius/v1'), + hf_client=AsyncInferenceClient(base_url='https://router.huggingface.co/nebius/v1'), api_key='test-api-key' ) assert provider.base_url == 'https://router.huggingface.co/nebius/v1' From 2f0ec5189dcd0607e24a2b61f15e745a4f835bef Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 14:20:06 +0100 Subject: [PATCH 09/27] fix provider test --- tests/providers/test_huggingface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py index 4df8d8f193..970c9d6366 100644 --- a/tests/providers/test_huggingface.py +++ b/tests/providers/test_huggingface.py @@ -49,7 +49,7 @@ def test_huggingface_provider_pass_http_client() -> None: def test_huggingface_provider_pass_hf_client() -> None: hf_client = AsyncInferenceClient(api_key='api-key') - provider = HuggingFaceProvider(hf_client=hf_client) + provider = HuggingFaceProvider(hf_client=hf_client, api_key='api-key') assert provider.client == hf_client From 69aee552602f6c382677a50d52410a83aaa4a9f0 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 14:21:16 +0100 Subject: [PATCH 10/27] adapt cli test --- tests/test_cli.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_cli.py b/tests/test_cli.py index 024116249c..8efc0da005 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -144,6 +144,7 @@ def test_list_models(capfd: CaptureFixture[str]): 'cohere', 'deepseek', 'heroku', + 'huggingface', ) models = {line.strip().split(' ')[0] for line in output[3:]} for provider in providers: From f68dacea3ba1315abe9e463548313c635b351334 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 14:26:51 +0100 Subject: [PATCH 11/27] re-record vcr cassettes --- .../test_hf_model_instructions.yaml | 70 +------------------ .../test_request_simple_success_with_vcr.yaml | 14 ++-- tests/models/test_huggingface.py | 4 +- 3 files changed, 11 insertions(+), 77 deletions(-) diff --git a/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml index f621f4c4f8..11bcb75969 100644 --- a/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml +++ b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml @@ -1,70 +1,4 @@ interactions: -- request: - body: null - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - 
- keep-alive - method: GET - uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping - response: - headers: - access-control-allow-origin: - - https://huggingface.co - access-control-expose-headers: - - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash - connection: - - keep-alive - content-length: - - '800' - content-type: - - application/json; charset=utf-8 - cross-origin-opener-policy: - - same-origin - etag: - - W/"320-IoLwHc4XKGzRoHW0ok1gY7tY/NI" - referrer-policy: - - strict-origin-when-cross-origin - vary: - - Origin - parsed_body: - _id: 66e81cefd1b1391042d0e47e - id: Qwen/Qwen2.5-72B-Instruct - inferenceProviderMapping: - featherless-ai: - providerId: Qwen/Qwen2.5-72B-Instruct - status: error - task: conversational - fireworks-ai: - providerId: accounts/fireworks/models/qwen2p5-72b-instruct - status: live - task: conversational - hf-inference: - providerId: Qwen/Qwen2.5-72B-Instruct - status: live - task: conversational - hyperbolic: - providerId: Qwen/Qwen2.5-72B-Instruct - status: live - task: conversational - nebius: - providerId: Qwen/Qwen2.5-72B-Instruct-fast - status: live - task: conversational - novita: - providerId: qwen/qwen-2.5-72b-instruct - status: live - task: conversational - together: - providerId: Qwen/Qwen2.5-72B-Instruct-Turbo - status: live - task: conversational - status: - code: 200 - message: OK - request: body: null headers: {} @@ -106,8 +40,8 @@ interactions: role: assistant tool_calls: [] stop_reason: null - created: 1749227878 - id: chatcmpl-54246cfb4fa046e88a984020c4efab20 + created: 1749475551 + id: chatcmpl-6fa46f85f4f04beda9c936d5996b22a8 model: Qwen/Qwen2.5-72B-Instruct-fast object: chat.completion prompt_logprobs: null diff --git a/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml b/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml index c9a3b50f2a..6996da0333 100644 --- a/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml +++ b/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml @@ -81,7 +81,7 @@ interactions: connection: - keep-alive content-length: - - '700' + - '680' content-type: - application/json cross-origin-opener-policy: @@ -99,27 +99,27 @@ interactions: logprobs: null message: audio: null - content: Hello! It's great to meet you. How can I assist you today? Whether you have questions, need information, - or just want to chat, I'm here to help! + content: Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with + anything specific. 
function_call: null reasoning_content: null refusal: null role: assistant tool_calls: [] stop_reason: null - created: 1749226637 - id: chatcmpl-f5783ce357b4415b8d59dbbf5b3cf9bf + created: 1749475549 + id: chatcmpl-6050852c70164258bb9bab4e93e2b69c model: Qwen/Qwen2.5-72B-Instruct-fast object: chat.completion prompt_logprobs: null service_tier: null system_fingerprint: null usage: - completion_tokens: 37 + completion_tokens: 29 completion_tokens_details: null prompt_tokens: 30 prompt_tokens_details: null - total_tokens: 67 + total_tokens: 59 status: code: 200 message: OK diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 731728b7b1..384328adff 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -657,7 +657,7 @@ async def test_request_simple_success_with_vcr(allow_model_requests: None, huggi agent = Agent(m) result = await agent.run('hello') assert result.output == snapshot( - "Hello! It's great to meet you. How can I assist you today? Whether you have questions, need information, or just want to chat, I'm here to help!" + 'Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with anything specific.' ) @@ -684,7 +684,7 @@ def simple_instructions(ctx: RunContext): usage=Usage(requests=1, request_tokens=26, response_tokens=2, total_tokens=28), model_name='Qwen/Qwen2.5-72B-Instruct-fast', timestamp=IsDatetime(), - vendor_id='chatcmpl-54246cfb4fa046e88a984020c4efab20', + vendor_id='chatcmpl-6fa46f85f4f04beda9c936d5996b22a8', ), ] ) From cc982e5271ce10ae147d34c12e8068df8437fae2 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 14:35:11 +0100 Subject: [PATCH 12/27] fix token name --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index a2c730a2a9..1240bab6f3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -283,7 +283,7 @@ def openrouter_api_key() -> str: @pytest.fixture(scope='session') def huggingface_api_key() -> str: - return os.getenv('HF_TOKEN', 'mock-api-key') or os.getenv('HUGGINGFACE_API_KEY', 'mock-api-key') + return os.getenv('HF_TOKEN', 'hf_token') or os.getenv('HUGGINGFACE_API_KEY', 'hf_token') @pytest.fixture(scope='session') From 00da46ecf0cf92374df4c5a82b714c279bd67ff3 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 9 Jun 2025 14:50:23 +0100 Subject: [PATCH 13/27] fix examples test --- docs/models/huggingface.md | 10 +++++----- tests/test_examples.py | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/models/huggingface.md b/docs/models/huggingface.md index 8edeb8cc3b..8d10a7ea8d 100644 --- a/docs/models/huggingface.md +++ b/docs/models/huggingface.md @@ -18,7 +18,7 @@ To use `HuggingFaceModel` through their main API, go to [Inference Providers doc Once you have a HuggingFace Token, you can set it as an environment variable: ```bash -export HF_TOKEN='your-hf-token' +export HF_TOKEN='hf_token' ``` You can then use `HuggingFaceModel` by name: @@ -53,7 +53,7 @@ from pydantic_ai import Agent from pydantic_ai.models.huggingface import HuggingFaceModel from pydantic_ai.providers.huggingface import HuggingFaceProvider -model = HuggingFaceModel('Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(api_key='your-api-key', provider="nebius")) +model = HuggingFaceModel('Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(api_key='hf_token', provider='nebius')) agent = Agent(model) ... 
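+# (Illustrative note: passing provider='nebius' pins the request to that
+# specific Inference Provider; omitting `provider` keeps the default 'auto'
+# selection described earlier on this page.)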
``` @@ -70,9 +70,9 @@ from pydantic_ai.models.huggingface import HuggingFaceModel from pydantic_ai.providers.huggingface import HuggingFaceProvider client = AsyncInferenceClient( - bill_to="openai", - api_key='your-api-key', - provider="fireworks-ai", + bill_to='openai', + api_key='hf_token', + provider='fireworks-ai', ) model = HuggingFaceModel( diff --git a/tests/test_examples.py b/tests/test_examples.py index ad377bedbf..977f336f03 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -137,6 +137,7 @@ def print(self, *args: Any, **kwargs: Any) -> None: env.set('CO_API_KEY', 'testing') env.set('MISTRAL_API_KEY', 'testing') env.set('ANTHROPIC_API_KEY', 'testing') + env.set('HF_TOKEN', 'hf_testing') env.set('AWS_ACCESS_KEY_ID', 'testing') env.set('AWS_SECRET_ACCESS_KEY', 'testing') env.set('AWS_DEFAULT_REGION', 'us-east-1') From 922fd13161f3a06cd6f95f43d28cb207e278b33e Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 25 Jun 2025 12:53:40 +0200 Subject: [PATCH 14/27] Add API docs and refactor a bit the wording --- docs/api/models/huggingface.md | 7 +++++++ docs/api/providers.md | 2 ++ docs/models/huggingface.md | 27 +++++++++++++++++---------- mkdocs.yml | 1 + 4 files changed, 27 insertions(+), 10 deletions(-) create mode 100644 docs/api/models/huggingface.md diff --git a/docs/api/models/huggingface.md b/docs/api/models/huggingface.md new file mode 100644 index 0000000000..72e78c4a3e --- /dev/null +++ b/docs/api/models/huggingface.md @@ -0,0 +1,7 @@ +# `pydantic_ai.models.huggingface` + +## Setup + +For details on how to set up authentication with this model, see [model configuration for Hugging Face](../../models/huggingface.md). + +::: pydantic_ai.models.huggingface diff --git a/docs/api/providers.md b/docs/api/providers.md index 926cf8e8b1..8e808185a8 100644 --- a/docs/api/providers.md +++ b/docs/api/providers.md @@ -29,3 +29,5 @@ ::: pydantic_ai.providers.heroku.HerokuProvider ::: pydantic_ai.providers.openrouter.OpenRouterProvider + +::: pydantic_ai.providers.huggingface.HuggingFaceProvider diff --git a/docs/models/huggingface.md b/docs/models/huggingface.md index 8d10a7ea8d..e99a77f00c 100644 --- a/docs/models/huggingface.md +++ b/docs/models/huggingface.md @@ -1,9 +1,8 @@ # Hugging Face - ## Install -To use `HuggingFace`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `huggingface` optional group: +To use `HuggingFaceModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `huggingface` optional group: ```bash pip/uv-add "pydantic-ai-slim[huggingface]" @@ -11,17 +10,19 @@ pip/uv-add "pydantic-ai-slim[huggingface]" ## Configuration -To use `HuggingFaceModel` through their main API, go to [Inference Providers documentation](https://huggingface.co/docs/inference-providers/pricing) for all the details, and you can generate a Hugging Face Token here: https://huggingface.co/settings/tokens. +To use [HuggingFace](https://huggingface.co/) through their main API, go to +[Inference Providers documentation](https://huggingface.co/docs/inference-providers/pricing) for all the details, +and you can generate a Hugging Face access token here: https://huggingface.co/settings/tokens. 
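+As a quick sanity check (an illustrative sketch, not prescribed by these docs), you can confirm the token works
+before wiring it into pydantic-ai. This assumes `huggingface_hub` is installed and `HF_TOKEN` is set:
+
+```python
+import os
+
+from huggingface_hub import whoami
+
+# Asks the Hub which account the token belongs to; raises if the token is
+# invalid or expired.
+print(whoami(token=os.environ['HF_TOKEN'])['name'])
+```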
-## Environment variable
+## Hugging Face access token
 
-Once you have a HuggingFace Token, you can set it as an environment variable:
+Once you have a Hugging Face access token, you can set it as an environment variable:
 
 ```bash
 export HF_TOKEN='hf_token'
 ```
 
-You can then use `HuggingFaceModel` by name:
+You can then use [`HuggingFaceModel`][pydantic_ai.models.huggingface.HuggingFaceModel] by name:
 
 ```python
 from pydantic_ai import Agent
@@ -41,12 +42,15 @@ agent = Agent(model)
 ...
 ```
 
-By default, the `HuggingFaceModel` uses the `HuggingFaceProvider` that will select automatically the first of the inference providers (Cerebras, Together AI, Cohere..etc) available for the model, sorted by your preferred order in https://hf.co/settings/inference-providers.
+By default, the [`HuggingFaceModel`][pydantic_ai.models.huggingface.HuggingFaceModel] uses the
+[`HuggingFaceProvider`][pydantic_ai.providers.huggingface.HuggingFaceProvider] that will automatically select
+the first of the inference providers (Cerebras, Together AI, Cohere, etc.) available for the model, sorted by your
+preferred order in https://hf.co/settings/inference-providers.
 
 ## Configure the provider
 
 If you want to pass parameters in code to the provider, you can programmatically instantiate the
-[HuggingFaceProvider][pydantic_ai.providers.huggingface.HuggingFaceProvider] and pass it to the model:
+[`HuggingFaceProvider`][pydantic_ai.providers.huggingface.HuggingFaceProvider] and pass it to the model:
 
 ```python
 from pydantic_ai import Agent
@@ -58,9 +62,12 @@ agent = Agent(model)
 ...
 ```
 
-## Custom Hugging Face Client
+## Custom Hugging Face client
 
-`HuggingFaceProvider` also accepts a custom `AsyncInferenceClient` client via the `hf_client` parameter, so you can customise the `headers`, `bill_to` (billing to an HF organization you're a member of), `base_url` etc. as defined in the [Hugging Face Hub python library docs](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client).
+[`HuggingFaceProvider`][pydantic_ai.providers.huggingface.HuggingFaceProvider] also accepts a custom
+[`AsyncInferenceClient`][huggingface_hub.AsyncInferenceClient] client via the `hf_client` parameter, so you can customise
+the `headers`, `bill_to` (billing to an HF organization you're a member of), `base_url`, etc., as defined in the
+[Hugging Face Hub python library docs](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client).
```python from huggingface_hub import AsyncInferenceClient diff --git a/mkdocs.yml b/mkdocs.yml index d750c29bbd..55fd86384e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -81,6 +81,7 @@ nav: - api/models/gemini.md - api/models/google.md - api/models/groq.md + - api/models/huggingface.md - api/models/instrumented.md - api/models/mistral.md - api/models/test.md From adfc2548918e804f7896adece711ac56c9ea7178 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 2 Jul 2025 17:57:10 +0200 Subject: [PATCH 15/27] review suggestions --- docs/models/huggingface.md | 2 +- .../pydantic_ai/providers/huggingface.py | 28 ++++++-- tests/conftest.py | 4 +- .../test_hf_model_instructions.yaml | 66 ++++++++++++++++++- tests/models/test_huggingface.py | 15 +++-- tests/providers/test_huggingface.py | 2 +- 6 files changed, 98 insertions(+), 19 deletions(-) diff --git a/docs/models/huggingface.md b/docs/models/huggingface.md index e99a77f00c..6425d76cb5 100644 --- a/docs/models/huggingface.md +++ b/docs/models/huggingface.md @@ -57,7 +57,7 @@ from pydantic_ai import Agent from pydantic_ai.models.huggingface import HuggingFaceModel from pydantic_ai.providers.huggingface import HuggingFaceProvider -model = HuggingFaceModel('Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(api_key='hf_token', provider='nebius')) +model = HuggingFaceModel('Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(api_key='hf_token', provider_name='nebius')) agent = Agent(model) ... ``` diff --git a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py index e18a60d16c..8afb415914 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import os +from typing import overload from httpx import AsyncClient @@ -32,13 +33,26 @@ def base_url(self) -> str: def client(self) -> AsyncInferenceClient: return self._client + @overload + def __init__(self, *, base_url: str, api_key: str | None = None) -> None: ... + @overload + def __init__(self, *, provider_name: str, api_key: str | None = None) -> None: ... + @overload + def __init__(self, *, hf_client: AsyncInferenceClient, api_key: str | None = None) -> None: ... + @overload + def __init__(self, *, hf_client: AsyncInferenceClient, base_url: str, api_key: str | None = None) -> None: ... + @overload + def __init__(self, *, hf_client: AsyncInferenceClient, provider_name: str, api_key: str | None = None) -> None: ... + @overload + def __init__(self, *, api_key: str | None = None) -> None: ... + def __init__( self, base_url: str | None = None, api_key: str | None = None, hf_client: AsyncInferenceClient | None = None, http_client: AsyncClient | None = None, - provider: str | None = None, + provider_name: str | None = None, ) -> None: """Create a new Hugging Face provider. @@ -50,9 +64,9 @@ def __init__( [`AsyncInferenceClient`](https://huggingface.co/docs/huggingface_hub/v0.29.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient) client to use. If not provided, a new instance will be created. http_client: (currently ignored) An existing `httpx.AsyncClient` to use for making HTTP requests. - provider : Name of the provider to use for inference. available providers can be found in the [HF Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners). + provider_name : Name of the provider to use for inference. 
available providers can be found in the [HF Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners). defaults to "auto", which will select the first available provider for the model, the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. - If `base_url` is passed, then `provider` is not used. + If `base_url` is passed, then `provider_name` is not used. """ api_key = api_key or os.environ.get('HF_TOKEN') @@ -63,12 +77,12 @@ def __init__( ) if http_client is not None: - raise ValueError('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead') + raise ValueError('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead.') - if base_url is not None and provider is not None: - raise ValueError('Cannot provide both `base_url` and `provider`') + if base_url is not None and provider_name is not None: + raise ValueError('Cannot provide both `base_url` and `provider_name`.') if hf_client is None: - self._client = AsyncInferenceClient(api_key=api_key, provider=provider, base_url=base_url) # type: ignore + self._client = AsyncInferenceClient(api_key=api_key, provider=provider_name, base_url=base_url) # type: ignore else: self._client = hf_client diff --git a/tests/conftest.py b/tests/conftest.py index 6cfc627bd3..73c6e07cc4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -294,7 +294,7 @@ def openrouter_api_key() -> str: @pytest.fixture(scope='session') def huggingface_api_key() -> str: - return os.getenv('HF_TOKEN', 'hf_token') or os.getenv('HUGGINGFACE_API_KEY', 'hf_token') + return os.getenv('HF_TOKEN', 'hf_token') @pytest.fixture(scope='session') @@ -428,7 +428,7 @@ def model( return HuggingFaceModel( 'Qwen/Qwen2.5-72B-Instruct', - provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key), + provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key), ) else: raise ValueError(f'Unknown model: {request.param}') diff --git a/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml index 11bcb75969..d8a5ee07e3 100644 --- a/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml +++ b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml @@ -1,4 +1,66 @@ interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '701' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"2bd-diYmxjldwbIbFgWNRPBqJ3SEIak" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hyperbolic: + 
providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: live + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK - request: body: null headers: {} @@ -40,8 +102,8 @@ interactions: role: assistant tool_calls: [] stop_reason: null - created: 1749475551 - id: chatcmpl-6fa46f85f4f04beda9c936d5996b22a8 + created: 1751470757 + id: chatcmpl-b3936940372c481b8d886e596dc75524 model: Qwen/Qwen2.5-72B-Instruct-fast object: chat.completion prompt_logprobs: null diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 384328adff..cae1e2bfeb 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -125,7 +125,8 @@ async def test_simple_completion(allow_model_requests: None): c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore mock_client = MockHuggingFace.create_mock(c) model = HuggingFaceModel( - 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x') + 'Qwen/Qwen2.5-72B-Instruct', + provider=HuggingFaceProvider(provider_name='nebius', hf_client=mock_client, api_key='x'), ) agent = Agent(model) @@ -148,7 +149,8 @@ async def test_request_simple_usage(allow_model_requests: None): c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore mock_client = MockHuggingFace.create_mock(c) model = HuggingFaceModel( - 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x') + 'Qwen/Qwen2.5-72B-Instruct', + provider=HuggingFaceProvider(provider_name='nebius', hf_client=mock_client, api_key='x'), ) agent = Agent(model) @@ -181,7 +183,8 @@ async def test_request_structured_response(allow_model_requests: None): mock_client = MockHuggingFace.create_mock(c) model = HuggingFaceModel( - 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', hf_client=mock_client, api_key='x') + 'Qwen/Qwen2.5-72B-Instruct', + provider=HuggingFaceProvider(provider_name='nebius', hf_client=mock_client, api_key='x'), ) agent = Agent(model, output_type=list[int]) @@ -652,7 +655,7 @@ def test_model_status_error(allow_model_requests: None) -> None: @pytest.mark.vcr() async def test_request_simple_success_with_vcr(allow_model_requests: None, huggingface_api_key: str): m = HuggingFaceModel( - 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key) + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key) ) agent = Agent(m) result = await agent.run('hello') @@ -664,7 +667,7 @@ async def test_request_simple_success_with_vcr(allow_model_requests: None, huggi @pytest.mark.vcr() async def test_hf_model_instructions(allow_model_requests: None, huggingface_api_key: str): m = HuggingFaceModel( - 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider='nebius', api_key=huggingface_api_key) + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key) ) def simple_instructions(ctx: RunContext): @@ -684,7 +687,7 @@ def simple_instructions(ctx: RunContext): usage=Usage(requests=1, request_tokens=26, response_tokens=2, total_tokens=28), 
model_name='Qwen/Qwen2.5-72B-Instruct-fast', timestamp=IsDatetime(), - vendor_id='chatcmpl-6fa46f85f4f04beda9c936d5996b22a8', + vendor_id='chatcmpl-b3936940372c481b8d886e596dc75524', ), ] ) diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py index 970c9d6366..944d418a01 100644 --- a/tests/providers/test_huggingface.py +++ b/tests/providers/test_huggingface.py @@ -44,7 +44,7 @@ def test_huggingface_provider_pass_http_client() -> None: ValueError, match=re.escape('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead'), ): - HuggingFaceProvider(http_client=http_client, api_key='api-key') + HuggingFaceProvider(http_client=http_client, api_key='api-key') # type: ignore def test_huggingface_provider_pass_hf_client() -> None: From e4af59eee94bb68253b11c4a2ee5db0a050daac1 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 7 Jul 2025 17:27:17 +0200 Subject: [PATCH 16/27] more tests --- tests/providers/test_huggingface.py | 69 +++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py index 944d418a01..9c6074af7d 100644 --- a/tests/providers/test_huggingface.py +++ b/tests/providers/test_huggingface.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import re +from unittest.mock import MagicMock, Mock, patch import httpx import pytest @@ -59,3 +60,71 @@ def test_hf_provider_with_base_url() -> None: hf_client=AsyncInferenceClient(base_url='https://router.huggingface.co/nebius/v1'), api_key='test-api-key' ) assert provider.base_url == 'https://router.huggingface.co/nebius/v1' + + +def test_huggingface_provider_properties(): + mock_client = Mock(spec=AsyncInferenceClient) + mock_client.model = 'test-model' + provider = HuggingFaceProvider(hf_client=mock_client) + assert provider.name == 'huggingface' + assert provider.base_url == 'test-model' + assert provider.client is mock_client + + +def test_huggingface_provider_init_api_key_error(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv('HF_TOKEN', raising=False) + with pytest.raises(UserError, match='Set the `HF_TOKEN` environment variable'): + HuggingFaceProvider() + + +@patch('pydantic_ai.providers.huggingface.AsyncInferenceClient') +def test_huggingface_provider_init_api_key_from_env( + MockAsyncInferenceClient: MagicMock, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setenv('HF_TOKEN', 'env-key') + HuggingFaceProvider() + MockAsyncInferenceClient.assert_called_with(api_key='env-key', provider=None, base_url=None) + + +@patch('pydantic_ai.providers.huggingface.AsyncInferenceClient') +def test_huggingface_provider_init_api_key_from_arg( + MockAsyncInferenceClient: MagicMock, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setenv('HF_TOKEN', 'env-key') + HuggingFaceProvider(api_key='arg-key') + MockAsyncInferenceClient.assert_called_with(api_key='arg-key', provider=None, base_url=None) + + +def test_huggingface_provider_init_http_client_error(): + with pytest.raises(ValueError, match='`http_client` is ignored'): + HuggingFaceProvider(api_key='key', http_client=Mock()) # type: ignore[call-overload] + + +def test_huggingface_provider_init_base_url_and_provider_name_error(): + with pytest.raises(ValueError, match='Cannot provide both `base_url` and `provider_name`'): + HuggingFaceProvider(api_key='key', base_url='url', provider_name='provider') # type: ignore[call-overload] + + +def test_huggingface_provider_init_with_hf_client(): + mock_client = 
Mock(spec=AsyncInferenceClient) + provider = HuggingFaceProvider(hf_client=mock_client) + assert provider.client is mock_client + + +@patch('pydantic_ai.providers.huggingface.AsyncInferenceClient') +def test_huggingface_provider_init_without_hf_client(MockAsyncInferenceClient: MagicMock): + provider = HuggingFaceProvider(api_key='key') + assert provider.client is MockAsyncInferenceClient.return_value + MockAsyncInferenceClient.assert_called_with(api_key='key', provider=None, base_url=None) + + +@patch('pydantic_ai.providers.huggingface.AsyncInferenceClient') +def test_huggingface_provider_init_with_provider_name(MockAsyncInferenceClient: MagicMock): + HuggingFaceProvider(api_key='key', provider_name='test-provider') + MockAsyncInferenceClient.assert_called_once_with(api_key='key', provider='test-provider', base_url=None) + + +@patch('pydantic_ai.providers.huggingface.AsyncInferenceClient') +def test_huggingface_provider_init_with_base_url(MockAsyncInferenceClient: MagicMock): + HuggingFaceProvider(api_key='key', base_url='test-url') + MockAsyncInferenceClient.assert_called_once_with(api_key='key', provider=None, base_url='test-url') From cd76d7871003a8e37a9bcccbd9c5af1fb492cd7a Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 7 Jul 2025 17:34:06 +0200 Subject: [PATCH 17/27] fix test --- tests/providers/test_huggingface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py index 9c6074af7d..858e049cb7 100644 --- a/tests/providers/test_huggingface.py +++ b/tests/providers/test_huggingface.py @@ -107,7 +107,7 @@ def test_huggingface_provider_init_base_url_and_provider_name_error(): def test_huggingface_provider_init_with_hf_client(): mock_client = Mock(spec=AsyncInferenceClient) - provider = HuggingFaceProvider(hf_client=mock_client) + provider = HuggingFaceProvider(hf_client=mock_client, api_key='key') assert provider.client is mock_client From 13ebbf9d5e4da1c128f358cecf28a2e1ad2b7fb6 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Mon, 7 Jul 2025 17:35:13 +0200 Subject: [PATCH 18/27] fix another test --- tests/providers/test_huggingface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py index 858e049cb7..c9d263ec66 100644 --- a/tests/providers/test_huggingface.py +++ b/tests/providers/test_huggingface.py @@ -65,7 +65,7 @@ def test_hf_provider_with_base_url() -> None: def test_huggingface_provider_properties(): mock_client = Mock(spec=AsyncInferenceClient) mock_client.model = 'test-model' - provider = HuggingFaceProvider(hf_client=mock_client) + provider = HuggingFaceProvider(hf_client=mock_client, api_key='test-api-key') assert provider.name == 'huggingface' assert provider.base_url == 'test-model' assert provider.client is mock_client From b96b7107d1d887feafb0dba2c89ff51ab512b423 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 9 Jul 2025 11:02:18 +0200 Subject: [PATCH 19/27] add more vcr tests --- .../pydantic_ai/models/huggingface.py | 2 - pydantic_ai_slim/pyproject.toml | 2 +- .../test_image_as_binary_content_input.yaml | 106 ++++++ .../test_image_url_input.yaml | 105 ++++++ ...ion_tokens[Qwen-Qwen2.5-72B-Instruct].yaml | 122 +++++++ ..._tokens[deepseek-ai-DeepSeek-R1-0528].yaml | 128 +++++++ ...ns[meta-llama-Llama-3.3-70B-Instruct].yaml | 142 ++++++++ .../test_request_simple_usage.yaml | 122 +++++++ .../test_simple_completion.yaml | 122 +++++++ .../test_stream_completion.yaml | 319 
++++++++++++++++++ tests/models/test_huggingface.py | 166 ++++++--- tests/providers/test_huggingface.py | 14 +- uv.lock | 33 +- 13 files changed, 1307 insertions(+), 76 deletions(-) create mode 100644 tests/models/cassettes/test_huggingface/test_image_as_binary_content_input.yaml create mode 100644 tests/models/cassettes/test_huggingface/test_image_url_input.yaml create mode 100644 tests/models/cassettes/test_huggingface/test_max_completion_tokens[Qwen-Qwen2.5-72B-Instruct].yaml create mode 100644 tests/models/cassettes/test_huggingface/test_max_completion_tokens[deepseek-ai-DeepSeek-R1-0528].yaml create mode 100644 tests/models/cassettes/test_huggingface/test_max_completion_tokens[meta-llama-Llama-3.3-70B-Instruct].yaml create mode 100644 tests/models/cassettes/test_huggingface/test_request_simple_usage.yaml create mode 100644 tests/models/cassettes/test_huggingface/test_simple_completion.yaml create mode 100644 tests/models/cassettes/test_huggingface/test_stream_completion.yaml diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index bee2926ec1..52744e0bdd 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -291,8 +291,6 @@ async def _map_messages( texts: list[str] = [] tool_calls: list[ChatCompletionInputToolCall] = [] for item in message.parts: - if isinstance(item, ThinkingPart): - continue if isinstance(item, TextPart): texts.append(item.content) elif isinstance(item, ToolCallPart): diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 998dda16f2..90872c6362 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -69,7 +69,7 @@ anthropic = ["anthropic>=0.52.0"] groq = ["groq>=0.19.0"] mistral = ["mistralai>=1.2.5"] bedrock = ["boto3>=1.37.24"] -huggingface = ["huggingface-hub[inference]>=0.32.0"] +huggingface = ["huggingface-hub[inference]>=0.33.2"] # Tools duckduckgo = ["duckduckgo-search>=7.0.0"] tavily = ["tavily-python>=0.5.0"] diff --git a/tests/models/cassettes/test_huggingface/test_image_as_binary_content_input.yaml b/tests/models/cassettes/test_huggingface/test_image_as_binary_content_input.yaml new file mode 100644 index 0000000000..8b295d4404 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_image_as_binary_content_input.yaml @@ -0,0 +1,106 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-VL-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '293' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"125-DEMuQsKZBCb9/68jW5UsI3Q7x7E" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 6797079422990ae89b5aff86 + id: Qwen/Qwen2.5-VL-72B-Instruct + inferenceProviderMapping: + hyperbolic: + providerId: Qwen/Qwen2.5-VL-72B-Instruct + status: live + task: conversational + nebius: + providerId: Qwen/Qwen2.5-VL-72B-Instruct + status: live + task: conversational + status: + code: 200 + message: OK +- request: + 
body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '776' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: The fruit in the image is a kiwi. It has been sliced in half, revealing its bright green flesh with small + black seeds arranged in a circular pattern around a white center. The outer skin of the kiwi is fuzzy and brown. + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1751986733 + id: chatcmpl-bd957b950cce4d61839e2af25f56f684 + model: Qwen/Qwen2.5-VL-72B-Instruct + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 49 + completion_tokens_details: null + prompt_tokens: 7625 + prompt_tokens_details: null + total_tokens: 7674 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_image_url_input.yaml b/tests/models/cassettes/test_huggingface/test_image_url_input.yaml new file mode 100644 index 0000000000..791a0aede5 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_image_url_input.yaml @@ -0,0 +1,105 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-VL-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '293' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"125-DEMuQsKZBCb9/68jW5UsI3Q7x7E" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 6797079422990ae89b5aff86 + id: Qwen/Qwen2.5-VL-72B-Instruct + inferenceProviderMapping: + hyperbolic: + providerId: Qwen/Qwen2.5-VL-72B-Instruct + status: live + task: conversational + nebius: + providerId: Qwen/Qwen2.5-VL-72B-Instruct + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '612' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - 
strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: Hello! How can I assist you with this image of a potato? + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1751983479 + id: chatcmpl-49aa100effab4ca28514d5ccc00d7944 + model: Qwen/Qwen2.5-VL-72B-Instruct + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 15 + completion_tokens_details: null + prompt_tokens: 269 + prompt_tokens_details: null + total_tokens: 284 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_max_completion_tokens[Qwen-Qwen2.5-72B-Instruct].yaml b/tests/models/cassettes/test_huggingface/test_max_completion_tokens[Qwen-Qwen2.5-72B-Instruct].yaml new file mode 100644 index 0000000000..8395c16fc6 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_max_completion_tokens[Qwen-Qwen2.5-72B-Instruct].yaml @@ -0,0 +1,122 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '704' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"2c0-CGiQuUurY/UiBTJC7RlRRjJtbZU" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: error + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: error + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: error + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '693' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: Hello! 
How can I assist you today? Whether you have questions, need help with something specific, or just + want to chat, I'm here to help! + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1752050598 + id: chatcmpl-5295b41092674918b860d41f723660cb + model: Qwen/Qwen2.5-72B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 33 + completion_tokens_details: null + prompt_tokens: 30 + prompt_tokens_details: null + total_tokens: 63 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_max_completion_tokens[deepseek-ai-DeepSeek-R1-0528].yaml b/tests/models/cassettes/test_huggingface/test_max_completion_tokens[deepseek-ai-DeepSeek-R1-0528].yaml new file mode 100644 index 0000000000..6f9868de9b --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_max_completion_tokens[deepseek-ai-DeepSeek-R1-0528].yaml @@ -0,0 +1,128 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/deepseek-ai/DeepSeek-R1-0528?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '678' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"2a6-gQg+B654Px2F2NUtLDU93uSoBDU" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 6836db82a3626cb7b5343be8 + id: deepseek-ai/DeepSeek-R1-0528 + inferenceProviderMapping: + fireworks-ai: + providerId: accounts/fireworks/models/deepseek-r1-0528 + status: live + task: conversational + hyperbolic: + providerId: deepseek-ai/DeepSeek-R1-0528 + status: live + task: conversational + nebius: + providerId: deepseek-ai/DeepSeek-R1-0528 + status: live + task: conversational + novita: + providerId: deepseek/deepseek-r1-0528 + status: live + task: conversational + sambanova: + providerId: DeepSeek-R1-0528 + status: live + task: conversational + together: + providerId: deepseek-ai/DeepSeek-R1 + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '1325' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: "\nOkay, the user just said “hello”. A simple greeting. They might be testing if I'm online, starting + a casual chat, or preparing a deeper question. 
\n\nSince they didn't add context, I'll match their tone—friendly + and open-ended. Short response invites them to lead. Adding the emoji makes it warmer. No need to overthink yet. + \n\nHmm… if they're new, they might need reassurance that I'm responsive. If they're regular users, they're probably + just warming up. Either way, keeping it light feels safe. \n\nWatch for clues in their next message—if they dive + into a topic, they were just being polite before asking. If they reply with small talk, they might want companionship.\n\nHello! + \U0001F60A How can I assist you today?" + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1752050599 + id: chatcmpl-25472217e5b643e0a1f3f20dd44ed2c1 + kv_transfer_params: null + model: deepseek-ai/DeepSeek-R1-0528 + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 165 + completion_tokens_details: null + prompt_tokens: 6 + prompt_tokens_details: null + total_tokens: 171 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_max_completion_tokens[meta-llama-Llama-3.3-70B-Instruct].yaml b/tests/models/cassettes/test_huggingface/test_max_completion_tokens[meta-llama-Llama-3.3-70B-Instruct].yaml new file mode 100644 index 0000000000..101f8f9e22 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_max_completion_tokens[meta-llama-Llama-3.3-70B-Instruct].yaml @@ -0,0 +1,142 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/meta-llama/Llama-3.3-70B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '1215' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"4bf-2c5rXKFDCLWF+O3TnkXoII8pC2U" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 6745f28f9333dfcc06268b1e + id: meta-llama/Llama-3.3-70B-Instruct + inferenceProviderMapping: + cerebras: + providerId: llama-3.3-70b + status: live + task: conversational + featherless-ai: + providerId: meta-llama/Llama-3.3-70B-Instruct + status: live + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/llama-v3p3-70b-instruct + status: live + task: conversational + groq: + providerId: llama-3.3-70b-versatile + status: live + task: conversational + hyperbolic: + providerId: meta-llama/Llama-3.3-70B-Instruct + status: live + task: conversational + nebius: + providerId: meta-llama/Llama-3.3-70B-Instruct-fast + status: live + task: conversational + novita: + providerId: meta-llama/llama-3.3-70b-instruct + status: live + task: conversational + nscale: + providerId: meta-llama/Llama-3.3-70B-Instruct + status: live + task: conversational + ovhcloud: + providerId: Meta-Llama-3_3-70B-Instruct + status: error + task: conversational + sambanova: + providerId: Meta-Llama-3.3-70B-Instruct + status: live + task: conversational + together: + providerId: meta-llama/Llama-3.3-70B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 
+ message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '686' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: '{"type": "function", "name": "print_output", "parameters": {"output": "hello"}}' + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: 128008 + created: 1752050609 + id: chatcmpl-e4e88c8a58b34ea8bd5c47e6265a0de3 + kv_transfer_params: null + model: meta-llama/Llama-3.3-70B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 23 + completion_tokens_details: null + prompt_tokens: 92 + prompt_tokens_details: null + total_tokens: 115 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_request_simple_usage.yaml b/tests/models/cassettes/test_huggingface/test_request_simple_usage.yaml new file mode 100644 index 0000000000..4025ce48a1 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_request_simple_usage.yaml @@ -0,0 +1,122 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '703' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"2bf-bkSLwumMG89/DZCsDWwBvtIEsEs" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: error + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: error + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - 
X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '712' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: Hello! It's great to meet you. How can I assist you today? Whether you have any questions, need some advice, + or just want to chat, feel free to let me know! + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1751982062 + id: chatcmpl-f366f315c05040fd9c4a505b516bce4b + model: Qwen/Qwen2.5-72B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 40 + completion_tokens_details: null + prompt_tokens: 30 + prompt_tokens_details: null + total_tokens: 70 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_simple_completion.yaml b/tests/models/cassettes/test_huggingface/test_simple_completion.yaml new file mode 100644 index 0000000000..a5f1d979ec --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_simple_completion.yaml @@ -0,0 +1,122 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '703' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"2bf-bkSLwumMG89/DZCsDWwBvtIEsEs" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: error + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: error + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '680' + content-type: + - application/json + 
cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with + anything specific. + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1751982153 + id: chatcmpl-d445c0d473a84791af2acf356cc00df7 + model: Qwen/Qwen2.5-72B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 29 + completion_tokens_details: null + prompt_tokens: 30 + prompt_tokens_details: null + total_tokens: 59 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_stream_completion.yaml b/tests/models/cassettes/test_huggingface/test_stream_completion.yaml new file mode 100644 index 0000000000..e592d3f271 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_stream_completion.yaml @@ -0,0 +1,319 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '703' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"2bf-bkSLwumMG89/DZCsDWwBvtIEsEs" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: error + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: error + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + body: + string: |+ + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":"Hello"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: 
{"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":"!"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" It"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" seems"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" like"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" your"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" message"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" might"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" have"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" been"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" cut"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" off"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" or"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" not"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" fully"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" 
sent"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":"."},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" Could"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" you"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" please"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" provide"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" more"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" details"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" so"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" I"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" can"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" assist"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" you"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" better"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: 
{"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":"?"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":""},"finish_reason":"stop","index":0,"logprobs":null,"stop_reason":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: [DONE] + + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-type: + - text/event-stream; charset=utf-8 + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + transfer-encoding: + - chunked + vary: + - Origin + status: + code: 200 + message: OK +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '703' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"2bf-bkSLwumMG89/DZCsDWwBvtIEsEs" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: error + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: error + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + body: + string: |+ + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":"Hello"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":"!"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: 
{"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" How"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" can"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" I"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" assist"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" you"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" today"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":"?"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" Feel"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" free"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" to"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" ask"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" me"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" any"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" questions"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" 
or"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" let"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" me"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" know"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" if"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" you"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" need"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" help"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" with"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" anything"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" specific"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":"."},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":""},"finish_reason":"stop","index":0,"logprobs":null,"stop_reason":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: [DONE] + + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-type: + - text/event-stream; charset=utf-8 + cross-origin-opener-policy: + - same-origin + referrer-policy: + - 
strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + transfer-encoding: + - chunked + vary: + - Origin + status: + code: 200 + message: OK +version: 1 +... diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index cae1e2bfeb..f2e346073b 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -27,9 +27,10 @@ UserPromptPart, ) from pydantic_ai.result import Usage +from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import RunContext -from ..conftest import IsDatetime, IsNow, raise_if_exception, try_import +from ..conftest import IsDatetime, IsNow, IsStr, raise_if_exception, try_import from .mock_async_stream import MockAsyncStream with try_import() as imports_successful: @@ -60,6 +61,7 @@ pytestmark = [ pytest.mark.skipif(not imports_successful(), reason='huggingface_hub not installed'), pytest.mark.anyio, + pytest.mark.filterwarnings('ignore::ResourceWarning'), ] @@ -121,45 +123,55 @@ def completion_message( ) -async def test_simple_completion(allow_model_requests: None): - c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore - mock_client = MockHuggingFace.create_mock(c) +@pytest.mark.vcr() +async def test_simple_completion(allow_model_requests: None, huggingface_api_key: str): model = HuggingFaceModel( 'Qwen/Qwen2.5-72B-Instruct', - provider=HuggingFaceProvider(provider_name='nebius', hf_client=mock_client, api_key='x'), + provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key), ) agent = Agent(model) result = await agent.run('hello') - assert result.output == 'world' + assert ( + result.output + == 'Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with anything specific.' + ) messages = result.all_messages() request = messages[0] response = messages[1] assert request.parts[0].content == 'hello' # type: ignore assert response == ModelResponse( - parts=[TextPart(content='world')], - usage=Usage(requests=1), - model_name='hf-model', - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + parts=[ + TextPart( + content='Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with anything specific.' + ) + ], + usage=Usage(requests=1, request_tokens=30, response_tokens=29, total_tokens=59), + model_name='Qwen/Qwen2.5-72B-Instruct-fast', + timestamp=datetime(2025, 7, 8, 13, 42, 33, tzinfo=timezone.utc), + vendor_id='chatcmpl-d445c0d473a84791af2acf356cc00df7', ) -async def test_request_simple_usage(allow_model_requests: None): - c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore - mock_client = MockHuggingFace.create_mock(c) +@pytest.mark.vcr() +async def test_request_simple_usage(allow_model_requests: None, huggingface_api_key: str): model = HuggingFaceModel( 'Qwen/Qwen2.5-72B-Instruct', - provider=HuggingFaceProvider(provider_name='nebius', hf_client=mock_client, api_key='x'), + provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key), ) agent = Agent(model) result = await agent.run('Hello') - assert result.output == 'world' - assert result.usage() == snapshot(Usage(requests=1)) + assert ( + result.output + == "Hello! It's great to meet you. How can I assist you today? Whether you have any questions, need some advice, or just want to chat, feel free to let me know!" 
+ ) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=30, response_tokens=40, total_tokens=70)) -async def test_request_structured_response(allow_model_requests: None): +async def test_request_structured_response( + allow_model_requests: None, +): tool_call = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore { 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore @@ -571,10 +583,12 @@ async def test_no_delta(allow_model_requests: None): assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) -async def test_image_url_input(allow_model_requests: None): - c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type:ignore - mock_client = MockHuggingFace.create_mock(c) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) +@pytest.mark.vcr() +async def test_image_url_input(allow_model_requests: None, huggingface_api_key: str): + m = HuggingFaceModel( + 'Qwen/Qwen2.5-VL-72B-Instruct', + provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key), + ) agent = Agent(m) result = await agent.run( @@ -599,46 +613,28 @@ async def test_image_url_input(allow_model_requests: None): ] ), ModelResponse( - parts=[TextPart(content='world')], - usage=Usage(requests=1), - model_name='hf-model', - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', + parts=[TextPart(content='Hello! How can I assist you with this image of a potato?')], + usage=Usage(requests=1, request_tokens=269, response_tokens=15, total_tokens=284), + model_name='Qwen/Qwen2.5-VL-72B-Instruct', + timestamp=datetime(2025, 7, 8, 14, 4, 39, tzinfo=timezone.utc), + vendor_id='chatcmpl-49aa100effab4ca28514d5ccc00d7944', ), ] ) -async def test_image_as_binary_content_input(allow_model_requests: None): - c = completion_message(ChatCompletionInputMessage(content='world', role='assistant')) # type: ignore - mock_client = MockHuggingFace.create_mock(c) - m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) - agent = Agent(m) - - base64_content = ( - b'/9j/4AAQSkZJRgABAQEAYABgAAD/4QBYRXhpZgAATU0AKgAAAAgAA1IBAAEAAAABAAAAPgIBAAEAAAABAAAARgMBAAEAAAABAAAA' - b'WgAAAAAAAAAE' +@pytest.mark.vcr() +async def test_image_as_binary_content_input( + allow_model_requests: None, image_content: BinaryContent, huggingface_api_key: str +): + m = HuggingFaceModel( + 'Qwen/Qwen2.5-VL-72B-Instruct', + provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key), ) - - result = await agent.run(['hello', BinaryContent(data=base64_content, media_type='image/jpeg')]) - assert result.all_messages() == snapshot( - [ - ModelRequest( - parts=[ - UserPromptPart( - content=['hello', BinaryContent(data=base64_content, media_type='image/jpeg')], - timestamp=IsNow(tz=timezone.utc), - ) - ] - ), - ModelResponse( - parts=[TextPart(content='world')], - usage=Usage(requests=1), - model_name='hf-model', - timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - vendor_id='123', - ), - ] + agent = Agent(m) + result = await agent.run(['What fruit is in the image?', image_content]) + assert result.output == snapshot( + 'The fruit in the image is a kiwi. It has been sliced in half, revealing its bright green flesh with small black seeds arranged in a circular pattern around a white center. The outer skin of the kiwi is fuzzy and brown.' 
) @@ -691,3 +687,61 @@ def simple_instructions(ctx: RunContext): ), ] ) + + +@pytest.mark.parametrize( + 'model_name', ['Qwen/Qwen2.5-72B-Instruct', 'deepseek-ai/DeepSeek-R1-0528', 'meta-llama/Llama-3.3-70B-Instruct'] +) +@pytest.mark.vcr() +async def test_max_completion_tokens(allow_model_requests: None, model_name: str, huggingface_api_key: str): + m = HuggingFaceModel(model_name, provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key)) + agent = Agent(m, model_settings=ModelSettings(max_tokens=100)) + + result = await agent.run('hello') + assert result.output == IsStr() + + +def test_system_property(): + model = HuggingFaceModel('some-model') + assert model.system == 'huggingface' + + +async def test_model_client_response_error(allow_model_requests: None) -> None: + try: + import aiohttp + except ImportError: + pytest.skip('aiohttp is not installed') + + request_info = Mock(spec=aiohttp.RequestInfo) + request_info.url = 'http://test.com' + request_info.method = 'POST' + request_info.headers = {} + request_info.real_url = 'http://test.com' + error = aiohttp.ClientResponseError(request_info, history=(), status=400, message='Bad Request') + error.response_error_payload = {'error': 'test error'} # type: ignore + + mock_client = MockHuggingFace.create_mock(error) + m = HuggingFaceModel('not_a_model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m) + with pytest.raises(ModelHTTPError) as exc_info: + await agent.run('hello') + assert str(exc_info.value) == snapshot("status_code: 400, model_name: not_a_model, body: {'error': 'test error'}") + + +async def test_process_response_no_created_timestamp(allow_model_requests: None): + c = completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'response', 'role': 'assistant'}), # type: ignore + ) + c.created = None # type: ignore + + mock_client = MockHuggingFace.create_mock(c) + model = HuggingFaceModel( + 'test-model', + provider=HuggingFaceProvider(hf_client=mock_client, api_key='x'), + ) + agent = Agent(model) + result = await agent.run('Hello') + messages = result.all_messages() + response_message = messages[1] + assert isinstance(response_message, ModelResponse) + assert response_message.timestamp == IsNow(tz=timezone.utc) diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py index c9d263ec66..c9570a54dc 100644 --- a/tests/providers/test_huggingface.py +++ b/tests/providers/test_huggingface.py @@ -67,7 +67,6 @@ def test_huggingface_provider_properties(): mock_client.model = 'test-model' provider = HuggingFaceProvider(hf_client=mock_client, api_key='test-api-key') assert provider.name == 'huggingface' - assert provider.base_url == 'test-model' assert provider.client is mock_client @@ -128,3 +127,16 @@ def test_huggingface_provider_init_with_provider_name(MockAsyncInferenceClient: def test_huggingface_provider_init_with_base_url(MockAsyncInferenceClient: MagicMock): HuggingFaceProvider(api_key='key', base_url='test-url') MockAsyncInferenceClient.assert_called_once_with(api_key='key', provider=None, base_url='test-url') + + +def test_huggingface_provider_init_api_key_is_none(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv('HF_TOKEN', raising=False) + with pytest.raises(UserError): + HuggingFaceProvider(api_key=None) + + +def test_huggingface_provider_base_url(): + mock_client = Mock(spec=AsyncInferenceClient) + mock_client.model = 'test-model' + provider = HuggingFaceProvider(hf_client=mock_client, 
api_key='test-api-key') + assert provider.base_url == 'test-model' diff --git a/uv.lock b/uv.lock index e18ce4b7d1..e7e928c321 100644 --- a/uv.lock +++ b/uv.lock @@ -1322,6 +1322,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259, upload-time = "2022-09-25T15:39:59.68Z" }, ] +[[package]] +name = "h2" +version = "4.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "hpack" }, + { name = "hyperframe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1b/38/d7f80fd13e6582fb8e0df8c9a653dcc02b03ca34f4d72f34869298c5baf8/h2-4.2.0.tar.gz", hash = "sha256:c8a52129695e88b1a0578d8d2cc6842bbd79128ac685463b887ee278126ad01f", size = 2150682, upload-time = "2025-02-02T07:43:51.815Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/9e/984486f2d0a0bd2b024bf4bc1c62688fcafa9e61991f041fb0e2def4a982/h2-4.2.0-py3-none-any.whl", hash = "sha256:479a53ad425bb29af087f3458a61d30780bc818e4ebcf01f0b536ba916462ed0", size = 60957, upload-time = "2025-02-01T11:02:26.481Z" }, +] + [[package]] name = "hf-xet" version = "1.1.3" @@ -1336,18 +1349,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/53/e3/2fcec58d2fcfd25ff07feb876f466cfa11f8dcf9d3b742c07fe9dd51ee0a/hf_xet-1.1.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:7c1a6aa6abed1f696f8099aa9796ca04c9ee778a58728a115607de9cc4638ff1", size = 4970349, upload-time = "2025-06-04T00:47:25.383Z" }, { url = "https://files.pythonhosted.org/packages/53/bf/10ca917e335861101017ff46044c90e517b574fbb37219347b83be1952f6/hf_xet-1.1.3-cp37-abi3-win_amd64.whl", hash = "sha256:b578ae5ac9c056296bb0df9d018e597c8dc6390c5266f35b5c44696003cde9f3", size = 2310934, upload-time = "2025-06-04T00:47:29.632Z" }, ] -[[package]] -name = "h2" -version = "4.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "hpack" }, - { name = "hyperframe" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/1b/38/d7f80fd13e6582fb8e0df8c9a653dcc02b03ca34f4d72f34869298c5baf8/h2-4.2.0.tar.gz", hash = "sha256:c8a52129695e88b1a0578d8d2cc6842bbd79128ac685463b887ee278126ad01f", size = 2150682, upload-time = "2025-02-02T07:43:51.815Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d0/9e/984486f2d0a0bd2b024bf4bc1c62688fcafa9e61991f041fb0e2def4a982/h2-4.2.0-py3-none-any.whl", hash = "sha256:479a53ad425bb29af087f3458a61d30780bc818e4ebcf01f0b536ba916462ed0", size = 60957, upload-time = "2025-02-01T11:02:26.481Z" }, -] [[package]] name = "hpack" @@ -1397,7 +1398,7 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "0.32.4" +version = "0.33.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -1409,9 +1410,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/60/c8/4f7d270285c46324fd66f62159eb16739aa5696f422dba57678a8c6b78e9/huggingface_hub-0.32.4.tar.gz", hash = "sha256:f61d45cd338736f59fb0e97550b74c24ee771bcc92c05ae0766b9116abe720be", size = 424494, upload-time = "2025-06-03T09:59:46.105Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/42/8a95c5632080ae312c0498744b2b852195e10b05a20b1be11c5141092f4c/huggingface_hub-0.33.2.tar.gz", hash = "sha256:84221defaec8fa09c090390cd68c78b88e3c4c2b7befba68d3dc5aacbc3c2c5f", size = 426637, 
upload-time = "2025-07-02T06:26:05.156Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/67/8b/222140f3cfb6f17b0dd8c4b9a0b36bd4ebefe9fb0098ba35d6960abcda0f/huggingface_hub-0.32.4-py3-none-any.whl", hash = "sha256:37abf8826b38d971f60d3625229221c36e53fe58060286db9baf619cfbf39767", size = 512101, upload-time = "2025-06-03T09:59:44.099Z" }, + { url = "https://files.pythonhosted.org/packages/44/f4/5f3f22e762ad1965f01122b42dae5bf0e009286e2dba601ce1d0dba72424/huggingface_hub-0.33.2-py3-none-any.whl", hash = "sha256:3749498bfa91e8cde2ddc2c1db92c79981f40e66434c20133b39e5928ac9bcc5", size = 515373, upload-time = "2025-07-02T06:26:03.072Z" }, ] [package.optional-dependencies] @@ -3174,7 +3175,7 @@ requires-dist = [ { name = "griffe", specifier = ">=1.3.2" }, { name = "groq", marker = "extra == 'groq'", specifier = ">=0.19.0" }, { name = "httpx", specifier = ">=0.27" }, - { name = "huggingface-hub", extras = ["inference"], marker = "extra == 'huggingface'", specifier = ">=0.32.0" }, + { name = "huggingface-hub", extras = ["inference"], marker = "extra == 'huggingface'", specifier = ">=0.33.2" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.9.4" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" }, From 9b0edeee85986666852d70fa8d8eb9265c1b2f92 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 9 Jul 2025 11:09:44 +0200 Subject: [PATCH 20/27] split thinking part --- pydantic_ai_slim/pydantic_ai/models/huggingface.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 52744e0bdd..405e991fab 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -9,6 +9,7 @@ from typing_extensions import assert_never +from pydantic_ai._thinking_part import split_content_into_text_and_thinking from pydantic_ai.providers import Provider, infer_provider from .. 
import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
@@ -242,12 +243,15 @@ def _process_response(self, response: ChatCompletionOutput) -> ModelResponse:
         timestamp = _now_utc()
 
         choice = response.choices[0]
+        content = choice.message.content
+        tool_calls = choice.message.tool_calls
+
         items: list[ModelResponsePart] = []
-        if choice.message.content is not None:
-            items.append(TextPart(choice.message.content))
-        if choice.message.tool_calls is not None:
-            for c in choice.message.tool_calls:
+        if content is not None:
+            items.extend(split_content_into_text_and_thinking(content))
+        if tool_calls is not None:
+            for c in tool_calls:
                 items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
 
         return ModelResponse(
             items,

From 6da1cf287b3f674017f7ed4dc7c030437d1b25ca Mon Sep 17 00:00:00 2001
From: Celina Hanouti
Date: Wed, 9 Jul 2025 11:13:35 +0200
Subject: [PATCH 21/27] fix tests

---
 tests/models/test_huggingface.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py
index f2e346073b..ee19e6fff4 100644
--- a/tests/models/test_huggingface.py
+++ b/tests/models/test_huggingface.py
@@ -702,7 +702,7 @@ async def test_max_completion_tokens(allow_model_requests: None, model_name: str
 
 
 def test_system_property():
-    model = HuggingFaceModel('some-model')
+    model = HuggingFaceModel('some-model', provider=HuggingFaceProvider(hf_client=Mock(), api_key='x'))
     assert model.system == 'huggingface'

From d546e0453cd1fe95ad1641391d9dedfb72de8f20 Mon Sep 17 00:00:00 2001
From: burtenshaw
Date: Wed, 9 Jul 2025 11:59:32 +0200
Subject: [PATCH 22/27] add more context to hugging face models page

---
 docs/models/huggingface.md | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/docs/models/huggingface.md b/docs/models/huggingface.md
index 6425d76cb5..a03e6e5a93 100644
--- a/docs/models/huggingface.md
+++ b/docs/models/huggingface.md
@@ -1,5 +1,7 @@
 # Hugging Face
 
+[Hugging Face](https://huggingface.co/) is an AI platform with all major open source models, datasets, MCPs, and demos. You can use [Inference Providers](https://huggingface.co/docs/inference-providers) to run open source models like DeepSeek R1 on scalable serverless infrastructure.
+
 ## Install
 
 To use `HuggingFaceModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `huggingface` optional group:
 
 ```bash
 pip/uv-add "pydantic-ai-slim[huggingface]"
 ```
 
 ## Configuration
 
-To use [HuggingFace](https://huggingface.co/) through their main API, go to
-[Inference Providers documentation](https://huggingface.co/docs/inference-providers/pricing) for all the details,
-and you can generate a Hugging Face access token here: https://huggingface.co/settings/tokens.
+To use [HuggingFace](https://huggingface.co/) inference, you'll need to set up an account which will give you [free tier](https://huggingface.co/docs/inference-providers/pricing) allowance on [Inference Providers](https://huggingface.co/docs/inference-providers). To set up inference, follow these steps:
 
-## Hugging Face access token
+1. Go to [Hugging Face](https://huggingface.co/join) and sign up for an account.
+2. Create a new access token in [Hugging Face](https://huggingface.co/settings/tokens).
+3. Set the `HF_TOKEN` environment variable to the token you just created.
Once you have a Hugging Face access token, you can set it as an environment variable:
@@ -22,6 +24,8 @@ Once you have a Hugging Face access token, you can set it as an environment vari
 export HF_TOKEN='hf_token'
 ```
 
+## Usage
+
 You can then use [`HuggingFaceModel`][pydantic_ai.models.huggingface.HuggingFaceModel] by name:
 
 ```python
From 789c26176b8705bd583fa6736c21e50ff1f20d2f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?c=C3=A9lina?=
Date: Wed, 9 Jul 2025 12:11:05 +0200
Subject: [PATCH 23/27] Update docs/models/huggingface.md

---
 docs/models/huggingface.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/models/huggingface.md b/docs/models/huggingface.md
index a03e6e5a93..61f8eef35f 100644
--- a/docs/models/huggingface.md
+++ b/docs/models/huggingface.md
@@ -12,7 +12,7 @@ pip/uv-add "pydantic-ai-slim[huggingface]"
 
 ## Configuration
 
-To use [HuggingFace](https://huggingface.co/) inference, you'll need to set up an account, which will give you a [free tier](https://huggingface.co/docs/inference-providers/pricing) allowance on [Inference Providers](https://huggingface.co/docs/inference-providers). To set up inference, follow these steps:
+To use [Hugging Face](https://huggingface.co/) inference, you'll need to set up an account, which will give you a [free tier](https://huggingface.co/docs/inference-providers/pricing) allowance on [Inference Providers](https://huggingface.co/docs/inference-providers). To set up inference, follow these steps:
 
 1. Go to [Hugging Face](https://huggingface.co/join) and sign up for an account.
 2. Create a new access token in [Hugging Face](https://huggingface.co/settings/tokens).
 3. Set the `HF_TOKEN` environment variable to the token you just created.
From 96a9c9601efbd5d23becd21708ae940d83a38c17 Mon Sep 17 00:00:00 2001
From: Celina Hanouti
Date: Wed, 9 Jul 2025 15:38:18 +0200
Subject: [PATCH 24/27] add more tests

---
 .../pydantic_ai/models/huggingface.py         |   2 +-
 .../test_hf_model_thinking_part.yaml          | 291 ++++++++++++++++++
 tests/models/test_huggingface.py              | 222 ++++++++++++-
 3 files changed, 511 insertions(+), 4 deletions(-)
 create mode 100644 tests/models/cassettes/test_huggingface/test_hf_model_thinking_part.yaml

diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py
index 405e991fab..09b0bec21b 100644
--- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py
+++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py
@@ -345,7 +345,7 @@ def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
                 },
             }
         )
-        if f.strict:
+        if f.strict is not None:
             tool_param['function']['strict'] = f.strict
         return tool_param

diff --git a/tests/models/cassettes/test_huggingface/test_hf_model_thinking_part.yaml b/tests/models/cassettes/test_huggingface/test_hf_model_thinking_part.yaml
new file mode 100644
index 0000000000..10be947804
--- /dev/null
+++ b/tests/models/cassettes/test_huggingface/test_hf_model_thinking_part.yaml
@@ -0,0 +1,291 @@
+interactions:
+- request:
+    body: null
+    headers:
+      accept:
+      - '*/*'
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+    method: GET
+    uri: https://huggingface.co/api/models/Qwen/Qwen3-235B-A22B?expand=inferenceProviderMapping
+  response:
+    headers:
+      access-control-allow-origin:
+      - https://huggingface.co
+      access-control-expose-headers:
+      - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash
+      connection:
+      - keep-alive
+      content-length:
+      - '470'
+      content-type:
+      - application/json; charset=utf-8
cross-origin-opener-policy: + - same-origin + etag: + - W/"1d6-5wPQfbCXoh8XtBVekhfceCwHN4Y" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 680daa4ac41c05ba341b67d1 + id: Qwen/Qwen3-235B-A22B + inferenceProviderMapping: + fireworks-ai: + providerId: accounts/fireworks/models/qwen3-235b-a22b + status: live + task: conversational + nebius: + providerId: Qwen/Qwen3-235B-A22B + status: live + task: conversational + novita: + providerId: qwen/qwen3-235b-a22b-fp8 + status: live + task: conversational + nscale: + providerId: Qwen/Qwen3-235B-A22B + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '5526' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: "\nOkay, the user is asking how to cross the street safely. Let me break this down step by step. + First, they need to look both ways to check for cars. But wait, should they check left, right, then left again? + I remember that's a common safety tip. They might be in a country where people drive on the right side or the + left, so maybe should I mention that?\n\nAlso, traffic signals and signs are important. What about pedestrian + crossings or traffic lights? Explaining when to walk when the signal is green and the cars have stopped. Oh, right, + sometimes people might not realize to wait for the walk signal. And even when using a crosswalk, you still need + to look both ways because cars might not stop.\n\nDistractions like phones or headphones. Yeah, people often get + hurt because they're looking at their phone while crossing. Should advise them to put away distractions and stay + alert. Kids and elderly folks might need extra care, like holding an adult's hand.\n\nWhat about if there's no + traffic light or crosswalk? Then finding the safest spot with good visibility, maybe near a corner where cars + can see them better. And teaching kids the basics of street safety.\n\nAlso, the confidence aspect—don't rush, + take your time, make eye contact with drivers. And what to do if stuck in the middle? Wait for the next signal. + Oh, and bicycles! In some places, bike lanes cross sidewalks, so being watchful for cyclists too.\n\nWait, should + I structure these points in a numbered list? Start with stopping at the curb, then looking both ways, checking + traffic signals, obeying signs, avoiding distractions, using crosswalks if possible, teaching kids, staying visible, + making eye contact, and what to do if stuck. Maybe add something about not assuming drivers see them and being + cautious.\n\nLet me make sure not to miss any key points. Also, mention that it's safer to cross at intersections. + And maybe a final note about local laws or practices varying by country. Yeah, that covers the main points. 
I + should present it clearly so it's easy to follow step by step without getting overwhelmed.\n\n\nCrossing + the street safely requires attention, patience, and following key steps. Here's a clear guide:\n\n1. **Stop at + the Curb**: Find a safe spot to pause before stepping onto the road.\n\n2. **Look Both Ways (Left, Right, Then + Left Again!)** \n - **First check left**: Look for oncoming traffic from your left (if driving is on the right + side in your country). \n - **Then check right**: Check for vehicles coming from the right. \n - **Final + glance left**: Recheck the direction of traffic closest to you before stepping off the curb. \n *(Reverse this + order if driving is on the left, as in the UK or Japan.)*\n\n3. **Use Traffic Signals and Crosswalks**: \n - + Wait for the pedestrian \"walk\" signal (green hand or similar). \n - If there’s no signal, only cross once + all vehicles have come to a complete stop and you’ve made eye contact with drivers. \n - Follow any painted + crosswalk lines and stay within them.\n\n4. **Obey Traffic Signs/Lights**: \n - Red/yellow lights mean stop. + Green means it’s safe to start crossing, but still watch for turning vehicles. \n - If the \"don’t walk\" signal + flashes while you’re mid-crossing, finish crossing without rushing.\n\n5. **Avoid Distractions**: \n - Put + away phones, earbuds, or anything that blocks your senses. \n - Keep your head up and stay alert to your surroundings.\n\n6. + **Be Visible and Predictable**: \n - Wear bright/light-colored clothing, especially at night. \n - Walk + (don’t run) and follow the flow of traffic. Avoid sudden changes in direction.\n\n7. **Teach Children Safely**: + \ \n - Hold young children’s hands. \n - Practice the \"stop, look, listen\" rule together. \n - Teach + them to make eye contact with drivers before crossing.\n\n8. **Cross at Intersections When Possible**: \n - + Drivers expect pedestrians at crosswalks and intersections. \n - If no crosswalk exists, choose a spot with + clear visibility (e.g., where you can see around parked cars).\n\n9. **Don’t Assume Drivers See You**: \n - + Even if a car stops, check for other vehicles that might not yield. \n - At night, use a flashlight or phone + light to stay visible.\n\n10. **What to Do if Stuck Mid-Street**: \n - If the light changes before you reach + the other side, stay calm. \n - Stop at the median or safety island and wait for the next signal. \n\n**Bonus + Tip**: In areas with bike lanes, check for cyclists even once you’ve started crossing. In some places, bikes ride + against traffic flow, so look both ways even on one-way streets.\n\n**Local Laws Matter**: Check rules in your + area—e.g., some places require yielding to pedestrians, while others prioritize drivers. Always prioritize your + safety over assumptions.\n\nFollow these steps, and you’ll cross the street confidently and safely every time! 
+ \U0001F6B6♀️ ✅" + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1752067065 + id: chatcmpl-957db61fe60d4440bcfe1f11f2c5b4b9 + model: Qwen/Qwen3-235B-A22B + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 1090 + completion_tokens_details: null + prompt_tokens: 15 + prompt_tokens_details: null + total_tokens: 1105 + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '9391' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: "\nOkay, the user previously asked how to cross the street, and I gave a detailed answer. Now they're + asking about crossing a river analogously. Let me start by understanding the connection. They want a similar structured + approach but for crossing a river.\n\nFirst, I need to figure out the equivalents between crossing a street and + a river. The original steps included looking both ways, using signals, avoiding distractions, etc. For a river, + physical steps might involve checking the current, choosing a safe spot, maybe using a bridge or boat.\n\nI should + map each street-crossing step to a river scenario. For example, \"stop at the curb\" becomes \"assess the riverbank.\" + Instead of traffic signals, check for ferry schedules or bridge access. Use safety equipment like life jackets + instead of wearing bright clothes.\n\nWait, the user mentioned \"analogously,\" so the structure should mirror + the previous answer but with river-specific actions. Maybe start by pausing to observe the river, checking water + flow instead of traffic. Use bridges as crosswalks and traffic signals. Boating has its own signals, like flags + or lights.\n\nAlso, think about hazards unique to rivers: strong currents, slippery rocks, wildlife. Safety considerations + here would involve knowing how to swim, having a floatation device, not relying on unreliable methods like jumping + on rocks.\n\nNeed to include steps like evaluating the crossing point, checking for bridges or fords, using boats + if necessary, avoiding hazards. Maybe mention time-sensitive factors like weather or flooding, similar to obeying + traffic lights.\n\nI should ensure each point from the previous answer has a parallel. For example, distractions + like phones would be like not paying attention to the river's flow. Visibility could mean wearing bright colors + to be seen on the water.\n\nAlso, consider group scenarios or children, making sure everyone knows the plan. Teaching + kids to cross safely here might involve knowing how to use floatation devices.\n\nI need to make sure the answer + is comprehensive but clear, following the same logical structure as the street crossing answer. 
Maybe list steps + in order, similar to the original list. Check for completeness: assessment, choosing the method, using proper + equipment, following safety protocols, dealing with emergencies mid-crossing, etc.\n\nWait, the original answer + had 10 points plus a bonus tip. Should mirror that structure? Maybe create a list with parallels. Also, ensure + that analogies are accurate—traffic becomes currents, crosswalks become bridges or fords, traffic signals become + navigational markers or ferry schedules.\n\nAvoiding distractions would involve focusing on the water's movement + and surroundings, not looking at devices. Visibility in water might involve reflective gear or signaling.\n\nAlso, + legal aspects: crossing at authorized spots only, checking if permits are needed. Local hazards or regulations + could be like local traffic laws.\n\nPotential mistakes: Maybe comparing the river's current to traffic is too + direct. Or missing steps like testing the stability of a path across the river, which isn't present in street + crossing. Need to add those unique elements naturally.\n\nAlso, emergencies: If mid-river and something goes wrong, + similar to being stuck on the street. Having a plan, like heading to a safer area or using safety equipment.\n\nAlright, + structure each point as an analogy, then expand with river-specific advice. Make sure the order parallels the + original. Use similar section titles but adapted. Maybe end with a note on local conditions and safety first.\n\n\nCrossing + a river safely shares many principles with crossing a street but involves adapting to the unpredictable nature + of water. Below is an analogy-based guide, mirroring the street-crossing steps while addressing river-specific + challenges:\n\n---\n\n### **1. Stop at the Riverbank: Assess the Situation**\n - **Pause and Observe** \n Like + stopping at the curb, stand back from the edge to evaluate the river’s width, speed, and hazards (swift currents, + rocks, debris, depth). \n - **Check for Bridges/Piers/Crossings** \n Just as intersections prioritize + pedestrian safety, bridges or marked fords exist for safer passage. Use them if accessible. \n\n---\n\n### **2. + Read the River: Look Both Upstream and Downstream** \n - **Scan Both Directions** \n Just as you look + left/right for cars, search **upstream (left if driving is right-hand)** for hazards like floating debris or sudden + surges. Check **downstream (right)** for exit points in case you’re swept away. \n - **Check the Flow** \n + \ Assess current strength: Is it a gentle trickle or a raging torrent? Avoid crossing if water is above knee-deep + or too fast. \n\n---\n\n### **3. Use Safe Routes: Bridges, Ferries, or Designated Fords** \n - **Follow Traffic + Signals → Follow Nautical Rules** \n Wait for ferry schedules, flashing lights (if present), or buoys marking + safe paths. Cross only when signals (like a ferry’s horn) indicate it’s safe. \n - **Choose a Footbridge or + Ferry** \n Bridges eliminate water risks entirely, much like crosswalks. Ferries or boats (with licensed + operators) are safest for wider rivers. \n\n---\n\n### **4. Prioritize Your Path: Know Where to Step or Swim** + \ \n - **Identify Stable Rocks or Shallows** \n If wading, pick a route with flat, secure footing (like + stepping stones) or the shallowest stretch, avoiding slippery algae-covered surfaces. \n - **Test the Current** + \ \n Before fully entering, use a stick or rock to gauge the force of the water. 
Swift currents can sweep + you off your feet faster than a car can strike. \n\n---\n\n### **5. Avoid Distractions: Focus on the Movement** + \ \n - **Put Away Devices** \n A phone distraction here risks losing balance in the river versus stepping + blindly into traffic. Keep both hands free for stability. \n - **Listen to the River** \n Gurgling or + roaring water warns of hidden holes or rapids—similar to hearing a car engine approaching. \n\n---\n\n### **6. + Be Predictable and Visible: Wear Bright Gear or Floats** \n - **Wear a Life Jacket** \n Like high-visibility + clothing, a life jacket keeps you buoyant and makes you easier for rescuers or boat operators to spot. \n - + **Stick to a Straight Route** \n Zigzagging in water wastes energy and increases the risk of losing balance, + just as darting across lanes on a street invites accidents. \n\n---\n\n### **7. Communicate: Make Eye Contact + with Boaters or Guides** \n - **Signal to Operators** \n In small boats or rafts, wave to catch the attention + of passing vessels (like making eye contact with drivers) to ensure they see you. \n - **Use Hand Signals or + Whistles** \n Agree on emergency signals with your group beforehand (e.g., pointing downstream to signal + danger). \n\n---\n\n### **8. Cross at the Safest Spot: Avoid Mid-River Surprises** \n - **Choose Wide, Slow + Sections** \n Like crossing at intersections, wide shallow areas have gentler currents. Avoid narrows where + water funnels into rapids. \n - **Watch for Hidden Dangers** \n Submerged logs, sudden drop-offs, or hypothermic + cold water can be as lethal as a speeding car. \n\n---\n\n### **9. Don’t Assume Safety: Verify Every Step or + Stroke** \n - **Test Each Footstep** \n Tap the riverbed before transferring weight to avoid stepping + into a hole or loose gravel (like checking for icy patches on a sidewalk). \n - **Swim Only If Trained** \n + \ If the river is too deep to wade, only swim if you know how. Use floatation devices if unsure—similar to + holding an adult’s hand as a child crossing a street. \n\n---\n\n### **10. Mid-River Emergencies: What to Do + if Stuck** \n - **If Struck by Current** \n Stay calm, float on your back with feet downstream (to avoid + head-first collisions), and steer toward eddies or the shore. \n - **If Trapped on a Rock** \n Hug a large + rock and wait for help, like pausing at a median. Don’t risk swimming diagonally across the river’s flow. \n\n---\n\n### + **Bonus Tip: Adapt to Local Conditions** \n - **Research Hazards** \n Some rivers have undertows, wildlife, + or pollution. Check local warnings (like road signs for blind corners or school zones). \n - **Weather Watch** + \ \n Sudden rainstorms can cause flash floods—delay crossing if clouds mass on the horizon. \n\n---\n\nBy + applying street-crossing principles to river navigation—patience, situational awareness, and prioritizing safe + infrastructure—you can minimize risks. Always assume the river is more dangerous than it appears, just as you’d + treat an unfamiliar road. 
**Safety first, crossing second!** \U0001F30A \U0001F6A4 ⚠️" + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1752067094 + id: chatcmpl-35fdec1307634f94a39f7e26f52e12a7 + model: Qwen/Qwen3-235B-A22B + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 1860 + completion_tokens_details: null + prompt_tokens: 691 + prompt_tokens_details: null + total_tokens: 2551 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index ee19e6fff4..90edf9c073 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -2,7 +2,7 @@ import json from collections.abc import Sequence -from dataclasses import dataclass +from dataclasses import asdict, dataclass, field from datetime import datetime, timezone from functools import cached_property from typing import Any, Literal, Union, cast @@ -15,22 +15,26 @@ from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior from pydantic_ai.exceptions import ModelHTTPError from pydantic_ai.messages import ( + AudioUrl, BinaryContent, + DocumentUrl, ImageUrl, ModelRequest, ModelResponse, RetryPromptPart, SystemPromptPart, TextPart, + ThinkingPart, ToolCallPart, ToolReturnPart, UserPromptPart, + VideoUrl, ) from pydantic_ai.result import Usage from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import RunContext -from ..conftest import IsDatetime, IsNow, IsStr, raise_if_exception, try_import +from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, raise_if_exception, try_import from .mock_async_stream import MockAsyncStream with try_import() as imports_successful: @@ -70,6 +74,7 @@ class MockHuggingFace: completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None stream: Sequence[MockStreamEvent] | Sequence[Sequence[MockStreamEvent]] | None = None index: int = 0 + chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list) @cached_property def chat(self) -> Any: @@ -87,8 +92,9 @@ def create_stream_mock( return cast(AsyncInferenceClient, cls(stream=stream)) async def chat_completions_create( - self, *_args: Any, stream: bool = False, **_kwargs: Any + self, *_args: Any, stream: bool = False, **kwargs: Any ) -> ChatCompletionOutput | MockAsyncStream[MockStreamEvent]: + self.chat_completion_kwargs.append(kwargs) if stream or self.stream: assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided' if isinstance(self.stream[0], Sequence): @@ -107,6 +113,13 @@ async def chat_completions_create( return response +def get_mock_chat_completion_kwargs(hf_client: AsyncInferenceClient) -> list[dict[str, Any]]: + if isinstance(hf_client, MockHuggingFace): + return hf_client.chat_completion_kwargs + else: # pragma: no cover + raise RuntimeError('Not a MockHuggingFace instance') + + def completion_message( message: ChatCompletionInputMessage | ChatCompletionOutputMessage, *, usage: ChatCompletionOutputUsage | None = None ) -> ChatCompletionOutput: @@ -745,3 +758,206 @@ async def test_process_response_no_created_timestamp(allow_model_requests: None) response_message = messages[1] assert isinstance(response_message, ModelResponse) assert response_message.timestamp == IsNow(tz=timezone.utc) + + +async def test_retry_prompt_without_tool_name(allow_model_requests: None): + responses = [ + completion_message( + 
ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'invalid-response', 'role': 'assistant'}) # type: ignore + ), + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'final-response', 'role': 'assistant'}) # type: ignore + ), + ] + + mock_client = MockHuggingFace.create_mock(responses) + model = HuggingFaceModel( + 'test-model', + provider=HuggingFaceProvider(hf_client=mock_client, api_key='x'), + ) + agent = Agent(model) + + @agent.output_validator + def response_validator(value: str) -> str: + if value == 'invalid-response': + raise ModelRetry('Response is invalid') + return value + + result = await agent.run('Hello') + assert result.output == 'final-response' + assert result.all_messages() == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), + ModelResponse( + parts=[TextPart(content='invalid-response')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Response is invalid', + tool_name=None, + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='final-response')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ] + ) + kwargs = get_mock_chat_completion_kwargs(mock_client)[1] + messages = kwargs['messages'] + assert {k: v for k, v in asdict(messages[-2]).items() if v is not None} == { + 'role': 'assistant', + 'content': 'invalid-response', + } + assert {k: v for k, v in asdict(messages[-1]).items() if v is not None} == { + 'role': 'user', + 'content': 'Validation feedback:\nResponse is invalid\n\nFix the errors and try again.', + } + + +async def test_thinking_part_in_history(allow_model_requests: None): + c = completion_message(ChatCompletionOutputMessage(content='response', role='assistant')) # type: ignore + mock_client = MockHuggingFace.create_mock(c) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(model) + messages = [ + ModelRequest(parts=[UserPromptPart(content='request')]), + ModelResponse( + parts=[ + TextPart(content='thought 1'), + ThinkingPart(content='this should be ignored'), + TextPart(content='thought 2'), + ], + model_name='hf-model', + timestamp=datetime.now(timezone.utc), + ), + ] + + await agent.run('another request', message_history=messages) + + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + sent_messages = kwargs['messages'] + assert [{k: v for k, v in asdict(m).items() if v is not None} for m in sent_messages] == snapshot( + [ + {'content': 'request', 'role': 'user'}, + {'content': 'thought 1\n\nthought 2', 'role': 'assistant'}, + {'content': 'another request', 'role': 'user'}, + ] + ) + + +@pytest.mark.parametrize('strict', [True, False, None]) +async def test_tool_strict_mode(allow_model_requests: None, strict: bool | None): + c = completion_message(ChatCompletionOutputMessage(content='response', role='assistant')) # type: ignore + mock_client = MockHuggingFace.create_mock(c) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(model) + + @agent.tool_plain(strict=strict) + def my_tool(x: int) -> int: + return x + + await agent.run('hello') + + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + tools = 
kwargs['tools'] + if strict is not None: + assert tools[0]['function']['strict'] is strict + else: + assert 'strict' not in tools[0]['function'] + + +@pytest.mark.parametrize( + 'content_item, error_message', + [ + (AudioUrl(url='url'), 'AudioUrl is not supported for Hugging Face'), + (DocumentUrl(url='url'), 'DocumentUrl is not supported for Hugging Face'), + (VideoUrl(url='url'), 'VideoUrl is not supported for Hugging Face'), + ], +) +async def test_unsupported_media_types(allow_model_requests: None, content_item: Any, error_message: str): + model = HuggingFaceModel( + 'Qwen/Qwen2.5-VL-72B-Instruct', + provider=HuggingFaceProvider(api_key='x'), + ) + agent = Agent(model) + + with pytest.raises(NotImplementedError, match=error_message): + await agent.run(['hello', content_item]) + + +@pytest.mark.vcr() +async def test_hf_model_thinking_part(allow_model_requests: None, huggingface_api_key: str): + m = HuggingFaceModel( + 'Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key) + ) + agent = Agent(m) + + result = await agent.run('How do I cross the street?') + assert result.all_messages() == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='How do I cross the street?', timestamp=IsDatetime())]), + ModelResponse( + parts=[ + IsInstance(ThinkingPart), + IsInstance(TextPart), + ], + usage=Usage(requests=1, request_tokens=15, response_tokens=1090, total_tokens=1105), + model_name='Qwen/Qwen3-235B-A22B', + timestamp=IsDatetime(), + vendor_id='chatcmpl-957db61fe60d4440bcfe1f11f2c5b4b9', + ), + ] + ) + + result = await agent.run( + 'Considering the way to cross the street, analogously, how do I cross the river?', + model=HuggingFaceModel( + 'Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key) + ), + message_history=result.all_messages(), + ) + assert result.all_messages() == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='How do I cross the street?', timestamp=IsDatetime())]), + ModelResponse( + parts=[ + IsInstance(ThinkingPart), + IsInstance(TextPart), + ], + usage=Usage(requests=1, request_tokens=15, response_tokens=1090, total_tokens=1105), + model_name='Qwen/Qwen3-235B-A22B', + timestamp=IsDatetime(), + vendor_id='chatcmpl-957db61fe60d4440bcfe1f11f2c5b4b9', + ), + ModelRequest( + parts=[ + UserPromptPart( + content='Considering the way to cross the street, analogously, how do I cross the river?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + IsInstance(ThinkingPart), + TextPart(content=IsStr()), + ], + usage=Usage(requests=1, request_tokens=691, response_tokens=1860, total_tokens=2551), + model_name='Qwen/Qwen3-235B-A22B', + timestamp=IsDatetime(), + vendor_id='chatcmpl-35fdec1307634f94a39f7e26f52e12a7', + ), + ] + ) From 60e5c74af208bf6cfcf623bb4a5adf4f526db994 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 9 Jul 2025 17:21:10 +0200 Subject: [PATCH 25/27] coverage --- tests/models/test_huggingface.py | 52 +++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 90edf9c073..bc6a7a359d 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -38,6 +38,7 @@ from .mock_async_stream import MockAsyncStream with try_import() as imports_successful: + import aiohttp from huggingface_hub import ( AsyncInferenceClient, ChatCompletionInputMessage, @@ -242,6 +243,22 @@ async def 
test_stream_completion(allow_model_requests: None): assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) +async def test_multiple_stream_calls(allow_model_requests: None): + stream = [ + [text_chunk('first '), text_chunk('call', finish_reason='stop')], + [text_chunk('second '), text_chunk('call', finish_reason='stop')], + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(model) + + async with agent.run_stream('first') as result: + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['first ', 'first call']) + + async with agent.run_stream('second') as result: + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['second ', 'second call']) + + async def test_request_tool_call(allow_model_requests: None): tool_call_1 = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore { @@ -720,11 +737,6 @@ def test_system_property(): async def test_model_client_response_error(allow_model_requests: None) -> None: - try: - import aiohttp - except ImportError: - pytest.skip('aiohttp is not installed') - request_info = Mock(spec=aiohttp.RequestInfo) request_info.url = 'http://test.com' request_info.method = 'POST' @@ -859,8 +871,31 @@ async def test_thinking_part_in_history(allow_model_requests: None): @pytest.mark.parametrize('strict', [True, False, None]) async def test_tool_strict_mode(allow_model_requests: None, strict: bool | None): - c = completion_message(ChatCompletionOutputMessage(content='response', role='assistant')) # type: ignore - mock_client = MockHuggingFace.create_mock(c) + tool_call = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore + { + 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore + { + 'name': 'my_tool', + 'arguments': '{"x": 42}', + } + ), + 'id': '1', + 'type': 'function', + } + ) + responses = [ + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': None, + 'role': 'assistant', + 'tool_calls': [tool_call], + } + ) + ), + completion_message(ChatCompletionOutputMessage(content='final response', role='assistant')), # type: ignore + ] + mock_client = MockHuggingFace.create_mock(responses) model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) agent = Agent(model) @@ -868,7 +903,8 @@ async def test_tool_strict_mode(allow_model_requests: None, strict: bool | None) def my_tool(x: int) -> int: return x - await agent.run('hello') + result = await agent.run('hello') + assert result.output == 'final response' kwargs = get_mock_chat_completion_kwargs(mock_client)[0] tools = kwargs['tools'] From f2a74dd4ed00a9b4b42dde6ba7efd4f9b4f1a7ae Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Wed, 9 Jul 2025 17:38:45 +0200 Subject: [PATCH 26/27] remove no-cover --- pydantic_ai_slim/pydantic_ai/models/huggingface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index 09b0bec21b..a569fd2f6a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -405,7 +405,7 @@ async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage: raise NotImplementedError('AudioUrl is not supported for Hugging Face') elif 
isinstance(item, DocumentUrl): raise NotImplementedError('DocumentUrl is not supported for Hugging Face') - elif isinstance(item, VideoUrl): # pragma: no cover + elif isinstance(item, VideoUrl): raise NotImplementedError('VideoUrl is not supported for Hugging Face') else: assert_never(item) From 9539909c8a3d0f5d4c511194a534ed9d54350df2 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 16 Jul 2025 16:48:15 +0200 Subject: [PATCH 27/27] replace exception handling --- .../pydantic_ai/models/huggingface.py | 33 ++++++++----------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index a569fd2f6a..41d53ca62a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -49,7 +49,6 @@ ChatCompletionOutput, ChatCompletionOutputMessage, ChatCompletionStreamOutput, - InferenceTimeoutError, ) from huggingface_hub.errors import HfHubHTTPError @@ -88,11 +87,9 @@ class HuggingFaceModelSettings(ModelSettings, total=False): - """Settings used for a Hugging Face model request. - - ALL FIELDS MUST BE `huggingface_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. - """ + """Settings used for a Hugging Face model request.""" + # ALL FIELDS MUST BE `huggingface_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. # This class is a placeholder for any future huggingface-specific settings @@ -220,20 +217,18 @@ async def _completions_create( top_logprobs=model_settings.get('top_logprobs', None), extra_body=model_settings.get('extra_body'), # type: ignore ) - except (InferenceTimeoutError, aiohttp.ClientResponseError, HfHubHTTPError) as e: - if isinstance(e, aiohttp.ClientResponseError): - raise ModelHTTPError( - status_code=e.status, - model_name=self.model_name, - body=e.response_error_payload, # type: ignore - ) from e - elif isinstance(e, HfHubHTTPError): - raise ModelHTTPError( - status_code=e.response.status_code, - model_name=self.model_name, - body=e.response.content, - ) from e - raise # pragma: lax no cover + except aiohttp.ClientResponseError as e: + raise ModelHTTPError( + status_code=e.status, + model_name=self.model_name, + body=e.response_error_payload, # type: ignore + ) from e + except HfHubHTTPError as e: + raise ModelHTTPError( + status_code=e.response.status_code, + model_name=self.model_name, + body=e.response.content, + ) from e def _process_response(self, response: ChatCompletionOutput) -> ModelResponse: """Process a non-streamed response, and prepare a message to return."""
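For callers, the practical effect of this final refactor is that both `aiohttp.ClientResponseError` and `HfHubHTTPError` now surface uniformly as `ModelHTTPError`. A minimal sketch of handling it, with a model name taken from the tests above and a placeholder token:

```python
from pydantic_ai import Agent
from pydantic_ai.exceptions import ModelHTTPError
from pydantic_ai.models.huggingface import HuggingFaceModel
from pydantic_ai.providers.huggingface import HuggingFaceProvider

model = HuggingFaceModel(
    'Qwen/Qwen2.5-VL-72B-Instruct',
    provider=HuggingFaceProvider(api_key='hf_...'),
)
agent = Agent(model)

try:
    result = agent.run_sync('Hello')
    print(result.output)
except ModelHTTPError as exc:
    # status_code, model_name and body are populated from whichever
    # underlying exception (aiohttp or huggingface_hub) was raised
    print(exc.status_code, exc.model_name, exc.body)
```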