diff --git a/docs/api/models/function.md b/docs/api/models/function.md index da173f808d..4049d757c6 100644 --- a/docs/api/models/function.md +++ b/docs/api/models/function.md @@ -36,7 +36,14 @@ async def model_function( print(info) """ AgentInfo( - function_tools=[], allow_text_output=True, output_tools=[], model_settings=None + function_tools=[], + allow_text_output=True, + output_tools=[], + model_settings=None, + model_request_parameters=ModelRequestParameters( + function_tools=[], builtin_tools=[], output_tools=[] + ), + instructions=None, ) """ return ModelResponse(parts=[TextPart('hello world')]) diff --git a/docs/builtin-tools.md b/docs/builtin-tools.md index 2c6a27e97a..1e4bdfd279 100644 --- a/docs/builtin-tools.md +++ b/docs/builtin-tools.md @@ -31,7 +31,7 @@ making it ideal for queries that require up-to-date data. |----------|-----------|-------| | OpenAI Responses | ✅ | Full feature support. To include search results on the [`BuiltinToolReturnPart`][pydantic_ai.messages.BuiltinToolReturnPart] that's available via [`ModelResponse.builtin_tool_calls`][pydantic_ai.messages.ModelResponse.builtin_tool_calls], enable the [`OpenAIResponsesModelSettings.openai_include_web_search_sources`][pydantic_ai.models.openai.OpenAIResponsesModelSettings.openai_include_web_search_sources] [model setting](agents.md#model-run-settings). | | Anthropic | ✅ | Full feature support | -| Google | ✅ | No parameter support. No [`BuiltinToolCallPart`][pydantic_ai.messages.BuiltinToolCallPart] or [`BuiltinToolReturnPart`][pydantic_ai.messages.BuiltinToolReturnPart] is generated when streaming. Using built-in tools and user tools (including [output tools](output.md#tool-output)) at the same time is not supported; to use structured output, use [`PromptedOutput`](output.md#prompted-output) instead. | +| Google | ✅ | No parameter support. No [`BuiltinToolCallPart`][pydantic_ai.messages.BuiltinToolCallPart] or [`BuiltinToolReturnPart`][pydantic_ai.messages.BuiltinToolReturnPart] is generated when streaming. Using built-in tools and function tools (including [output tools](output.md#tool-output)) at the same time is not supported; to use structured output, use [`PromptedOutput`](output.md#prompted-output) instead. | | Groq | ✅ | Limited parameter support. To use web search capabilities with Groq, you need to use the [compound models](https://console.groq.com/docs/compound). | | OpenAI Chat Completions | ❌ | Not supported | | Bedrock | ❌ | Not supported | @@ -123,7 +123,7 @@ in a secure environment, making it perfect for computational tasks, data analysi | Provider | Supported | Notes | |----------|-----------|-------| | OpenAI | ✅ | To include code execution output on the [`BuiltinToolReturnPart`][pydantic_ai.messages.BuiltinToolReturnPart] that's available via [`ModelResponse.builtin_tool_calls`][pydantic_ai.messages.ModelResponse.builtin_tool_calls], enable the [`OpenAIResponsesModelSettings.openai_include_code_execution_outputs`][pydantic_ai.models.openai.OpenAIResponsesModelSettings.openai_include_code_execution_outputs] [model setting](agents.md#model-run-settings). If the code execution generated images, like charts, they will be available on [`ModelResponse.images`][pydantic_ai.messages.ModelResponse.images] as [`BinaryImage`][pydantic_ai.messages.BinaryImage] objects. The generated image can also be used as [image output](output.md#image-output) for the agent run. 
| -| Google | ✅ | Using built-in tools and user tools (including [output tools](output.md#tool-output)) at the same time is not supported; to use structured output, use [`PromptedOutput`](output.md#prompted-output) instead. | +| Google | ✅ | Using built-in tools and function tools (including [output tools](output.md#tool-output)) at the same time is not supported; to use structured output, use [`PromptedOutput`](output.md#prompted-output) instead. | | Anthropic | ✅ | | | Groq | ❌ | | | Bedrock | ❌ | | @@ -315,7 +315,7 @@ allowing it to pull up-to-date information from the web. | Provider | Supported | Notes | |----------|-----------|-------| -| Google | ✅ | No [`BuiltinToolCallPart`][pydantic_ai.messages.BuiltinToolCallPart] or [`BuiltinToolReturnPart`][pydantic_ai.messages.BuiltinToolReturnPart] is currently generated; please submit an issue if you need this. Using built-in tools and user tools (including [output tools](output.md#tool-output)) at the same time is not supported; to use structured output, use [`PromptedOutput`](output.md#prompted-output) instead. | +| Google | ✅ | No [`BuiltinToolCallPart`][pydantic_ai.messages.BuiltinToolCallPart] or [`BuiltinToolReturnPart`][pydantic_ai.messages.BuiltinToolReturnPart] is currently generated; please submit an issue if you need this. Using built-in tools and function tools (including [output tools](output.md#tool-output)) at the same time is not supported; to use structured output, use [`PromptedOutput`](output.md#prompted-output) instead. | | OpenAI | ❌ | | | Anthropic | ❌ | | | Groq | ❌ | | diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 149e7c97dc..620c0639e3 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -374,9 +374,10 @@ async def _prepare_request_parameters( ) -> models.ModelRequestParameters: """Build tools and create an agent model.""" output_schema = ctx.deps.output_schema - output_object = None - if isinstance(output_schema, _output.NativeOutputSchema): - output_object = output_schema.object_def + + prompted_output_template = ( + output_schema.template if isinstance(output_schema, _output.PromptedOutputSchema) else None + ) function_tools: list[ToolDefinition] = [] output_tools: list[ToolDefinition] = [] @@ -391,7 +392,8 @@ async def _prepare_request_parameters( builtin_tools=ctx.deps.builtin_tools, output_mode=output_schema.mode, output_tools=output_tools, - output_object=output_object, + output_object=output_schema.object_def, + prompted_output_template=prompted_output_template, allow_text_output=output_schema.allows_text, allow_image_output=output_schema.allows_image, ) @@ -489,7 +491,6 @@ async def _prepare_request( message_history = _clean_message_history(message_history) model_request_parameters = await _prepare_request_parameters(ctx) - model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters) model_settings = ctx.deps.model_settings usage = ctx.state.usage @@ -570,7 +571,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa # we got an empty response. 
# this sometimes happens with anthropic (and perhaps other models) # when the model has already returned text along side tool calls - if text_processor := output_schema.text_processor: + if text_processor := output_schema.text_processor: # pragma: no branch # in this scenario, if text responses are allowed, we return text from the most recent model # response, if any for message in reversed(ctx.state.message_history): @@ -584,8 +585,12 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa # not part of the final result output, so we reset the accumulated text text = '' # pragma: no cover if text: - self._next_node = await self._handle_text_response(ctx, text, text_processor) - return + try: + self._next_node = await self._handle_text_response(ctx, text, text_processor) + return + except ToolRetryError: + # If the text from the preview response was invalid, ignore it. + pass # Go back to the model request node with an empty request, which means we'll essentially # resubmit the most recent request that resulted in an empty response, @@ -622,11 +627,11 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa else: assert_never(part) - # At the moment, we prioritize at least executing tool calls if they are present. - # In the future, we'd consider making this configurable at the agent or run level. - # This accounts for cases like anthropic returns that might contain a text response - # and a tool call response, where the text response just indicates the tool call will happen. try: + # At the moment, we prioritize at least executing tool calls if they are present. + # In the future, we'd consider making this configurable at the agent or run level. + # This accounts for cases like anthropic returns that might contain a text response + # and a tool call response, where the text response just indicates the tool call will happen. 
alternatives: list[str] = [] if tool_calls: async for event in self._handle_tool_calls(ctx, tool_calls): diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 67ee9f3017..ebb737a1cf 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -10,7 +10,7 @@ from pydantic import Json, TypeAdapter, ValidationError from pydantic_core import SchemaValidator, to_json -from typing_extensions import Self, TypedDict, TypeVar, assert_never +from typing_extensions import Self, TypedDict, TypeVar from pydantic_ai._instrumentation import InstrumentationNames @@ -26,7 +26,6 @@ OutputSpec, OutputTypeOrFunction, PromptedOutput, - StructuredOutputMode, TextOutput, TextOutputFunc, ToolOutput, @@ -36,7 +35,7 @@ from .toolsets.abstract import AbstractToolset, ToolsetTool if TYPE_CHECKING: - from .profiles import ModelProfile + pass T = TypeVar('T') """An invariant TypeVar.""" @@ -212,59 +211,30 @@ async def validate( @dataclass(kw_only=True) -class BaseOutputSchema(ABC, Generic[OutputDataT]): +class OutputSchema(ABC, Generic[OutputDataT]): text_processor: BaseOutputProcessor[OutputDataT] | None = None toolset: OutputToolset[Any] | None = None + object_def: OutputObjectDefinition | None = None allows_deferred_tools: bool = False allows_image: bool = False - @abstractmethod - def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: + @property + def mode(self) -> OutputMode: raise NotImplementedError() @property def allows_text(self) -> bool: return self.text_processor is not None - -@dataclass(init=False) -class OutputSchema(BaseOutputSchema[OutputDataT], ABC): - """Model the final output from an agent run.""" - - @classmethod - @overload - def build( - cls, - output_spec: OutputSpec[OutputDataT], - *, - default_mode: StructuredOutputMode, - name: str | None = None, - description: str | None = None, - strict: bool | None = None, - ) -> OutputSchema[OutputDataT]: ... - - @classmethod - @overload - def build( - cls, - output_spec: OutputSpec[OutputDataT], - *, - default_mode: None = None, - name: str | None = None, - description: str | None = None, - strict: bool | None = None, - ) -> BaseOutputSchema[OutputDataT]: ... 
- @classmethod def build( # noqa: C901 cls, output_spec: OutputSpec[OutputDataT], *, - default_mode: StructuredOutputMode | None = None, name: str | None = None, description: str | None = None, strict: bool | None = None, - ) -> BaseOutputSchema[OutputDataT]: + ) -> OutputSchema[OutputDataT]: """Build an OutputSchema dataclass from an output type.""" outputs = _flatten_output_spec(output_spec) @@ -382,15 +352,12 @@ def build( # noqa: C901 ) if len(other_outputs) > 0: - schema = OutputSchemaWithoutMode( + return AutoOutputSchema( processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict), toolset=toolset, allows_deferred_tools=allows_deferred_tools, allows_image=allows_image, ) - if default_mode: - schema = schema.with_default_mode(default_mode) - return schema if allows_image: return ImageOutputSchema(allows_deferred_tools=allows_deferred_tools) @@ -410,22 +377,9 @@ def _build_processor( return UnionOutputProcessor(outputs=outputs, strict=strict, name=name, description=description) - @property - @abstractmethod - def mode(self) -> OutputMode: - raise NotImplementedError() - - def raise_if_unsupported(self, profile: ModelProfile) -> None: - """Raise an error if the mode is not supported by this model.""" - if self.allows_image and not profile.supports_image_output: - raise UserError('Image output is not supported by this model.') - - def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: - return self - @dataclass(init=False) -class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]): +class AutoOutputSchema(OutputSchema[OutputDataT]): processor: BaseObjectOutputProcessor[OutputDataT] def __init__( @@ -439,32 +393,17 @@ def __init__( # At that point we may not know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time, # but we cover ourselves just in case we end up using the tool output mode. 
super().__init__( - allows_deferred_tools=allows_deferred_tools, toolset=toolset, + object_def=processor.object_def, text_processor=processor, + allows_deferred_tools=allows_deferred_tools, allows_image=allows_image, ) self.processor = processor - def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: - if mode == 'native': - return NativeOutputSchema( - processor=self.processor, - allows_deferred_tools=self.allows_deferred_tools, - allows_image=self.allows_image, - ) - elif mode == 'prompted': - return PromptedOutputSchema( - processor=self.processor, - allows_deferred_tools=self.allows_deferred_tools, - allows_image=self.allows_image, - ) - elif mode == 'tool': - return ToolOutputSchema( - toolset=self.toolset, allows_deferred_tools=self.allows_deferred_tools, allows_image=self.allows_image - ) - else: - assert_never(mode) + @property + def mode(self) -> OutputMode: + return 'auto' @dataclass(init=False) @@ -486,10 +425,6 @@ def __init__( def mode(self) -> OutputMode: return 'text' - def raise_if_unsupported(self, profile: ModelProfile) -> None: - """Raise an error if the mode is not supported by this model.""" - super().raise_if_unsupported(profile) - class ImageOutputSchema(OutputSchema[OutputDataT]): def __init__(self, *, allows_deferred_tools: bool): @@ -499,11 +434,6 @@ def __init__(self, *, allows_deferred_tools: bool): def mode(self) -> OutputMode: return 'image' - def raise_if_unsupported(self, profile: ModelProfile) -> None: - """Raise an error if the mode is not supported by this model.""" - # This already raises if image output is not supported by this model. - super().raise_if_unsupported(profile) - @dataclass(init=False) class StructuredTextOutputSchema(OutputSchema[OutputDataT], ABC): @@ -513,25 +443,19 @@ def __init__( self, *, processor: BaseObjectOutputProcessor[OutputDataT], allows_deferred_tools: bool, allows_image: bool ): super().__init__( - text_processor=processor, allows_deferred_tools=allows_deferred_tools, allows_image=allows_image + text_processor=processor, + object_def=processor.object_def, + allows_deferred_tools=allows_deferred_tools, + allows_image=allows_image, ) self.processor = processor - @property - def object_def(self) -> OutputObjectDefinition: - return self.processor.object_def - class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]): @property def mode(self) -> OutputMode: return 'native' - def raise_if_unsupported(self, profile: ModelProfile) -> None: - """Raise an error if the mode is not supported by this model.""" - if not profile.supports_json_schema_output: - raise UserError('Native structured output is not supported by this model.') - @dataclass(init=False) class PromptedOutputSchema(StructuredTextOutputSchema[OutputDataT]): @@ -570,14 +494,11 @@ def build_instructions(cls, template: str, object_def: OutputObjectDefinition) - return template.format(schema=json.dumps(schema)) - def raise_if_unsupported(self, profile: ModelProfile) -> None: - """Raise an error if the mode is not supported by this model.""" - super().raise_if_unsupported(profile) - - def instructions(self, default_template: str) -> str: + def instructions(self, default_template: str) -> str: # pragma: no cover """Get instructions to tell model to output JSON matching the schema.""" template = self.template or default_template object_def = self.object_def + assert object_def is not None return self.build_instructions(template, object_def) @@ -602,12 +523,6 @@ def __init__( def mode(self) -> OutputMode: return 'tool' - def 
raise_if_unsupported(self, profile: ModelProfile) -> None: - """Raise an error if the mode is not supported by this model.""" - super().raise_if_unsupported(profile) - if not profile.supports_tools: - raise UserError('Tool output is not supported by this model.') - class BaseOutputProcessor(ABC, Generic[OutputDataT]): @abstractmethod diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index fd0f59ff38..5bcfa6baae 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -8,7 +8,7 @@ from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar -from typing import TYPE_CHECKING, Any, ClassVar, cast, overload +from typing import TYPE_CHECKING, Any, ClassVar, overload from opentelemetry.trace import NoOpTracer, use_span from pydantic.json_schema import GenerateJsonSchema @@ -39,7 +39,6 @@ from ..builtin_tools import AbstractBuiltinTool from ..models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model from ..output import OutputDataT, OutputSpec -from ..profiles import ModelProfile from ..run import AgentRun, AgentRunResult from ..settings import ModelSettings, merge_model_settings from ..tools import ( @@ -133,7 +132,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]): _instrument_default: ClassVar[InstrumentationSettings | bool] = False _deps_type: type[AgentDepsT] = dataclasses.field(repr=False) - _output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False) + _output_schema: _output.OutputSchema[OutputDataT] = dataclasses.field(repr=False) _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False) _instructions: list[str | _system_prompt.SystemPromptFunc[AgentDepsT]] = dataclasses.field(repr=False) _system_prompts: tuple[str, ...] = dataclasses.field(repr=False) @@ -303,11 +302,7 @@ def __init__( _utils.validate_empty_kwargs(_deprecated_kwargs) - default_output_mode = ( - self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None - ) - - self._output_schema = _output.OutputSchema[OutputDataT].build(output_type, default_mode=default_output_mode) + self._output_schema = _output.OutputSchema[OutputDataT].build(output_type) self._output_validators = [] self._instructions = self._normalize_instructions(instructions) @@ -452,7 +447,7 @@ async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[Any] | None = None, message_history: Sequence[_messages.ModelMessage] | None = None, deferred_tool_results: DeferredToolResults | None = None, model: models.Model | models.KnownModelName | str | None = None, @@ -549,18 +544,18 @@ async def main(): del model deps = self._get_deps(deps) - output_schema = self._prepare_output_schema(output_type, model_used.profile) + output_schema = self._prepare_output_schema(output_type) output_type_ = output_type or self.output_type # We consider it a user error if a user tries to restrict the result type while having an output validator that # may change the result type from the restricted type to something else. 
Therefore, we consider the following # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. - output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) + output_validators = self._output_validators output_toolset = self._output_toolset if output_schema != self._output_schema or output_validators: - output_toolset = cast(OutputToolset[AgentDepsT], output_schema.toolset) + output_toolset = output_schema.toolset if output_toolset: output_toolset.max_retries = self._max_result_retries output_toolset.output_validators = output_validators @@ -592,11 +587,6 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: *[await func.run(run_context) for func in instructions_functions], ] - model_profile = model_used.profile - if isinstance(output_schema, _output.PromptedOutputSchema): - instructions = output_schema.instructions(model_profile.prompted_output_template) - parts.append(instructions) - parts = [p for p in parts if p] if not parts: return None @@ -609,7 +599,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: instrumentation_settings = None tracer = NoOpTracer() - graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT]( + graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, OutputDataT]( user_deps=deps, prompt=user_prompt, new_message_index=len(message_history) if message_history else 0, @@ -1418,21 +1408,23 @@ def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]: return toolsets + @overload + def _prepare_output_schema(self, output_type: None) -> _output.OutputSchema[OutputDataT]: ... + + @overload def _prepare_output_schema( - self, output_type: OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile - ) -> _output.OutputSchema[RunOutputDataT]: + self, output_type: OutputSpec[RunOutputDataT] + ) -> _output.OutputSchema[RunOutputDataT]: ... + + def _prepare_output_schema(self, output_type: OutputSpec[Any] | None) -> _output.OutputSchema[Any]: if output_type is not None: if self._output_validators: raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators') - schema = _output.OutputSchema[RunOutputDataT].build( - output_type, default_mode=model_profile.default_structured_output_mode - ) + schema = _output.OutputSchema.build(output_type) else: - schema = self._output_schema.with_default_mode(model_profile.default_structured_output_mode) - - schema.raise_if_unsupported(model_profile) + schema = self._output_schema - return schema # pyright: ignore[reportReturnType] + return schema async def __aenter__(self) -> Self: """Enter the agent context. 
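Since the agent no longer appends prompted-output instructions or resolves the default structured output mode itself, a minimal sketch of where that information now surfaces may help. This is not part of the PR; it assumes the `FunctionModel`/`AgentInfo` API shown in the `docs/api/models/function.md` hunk above plus the public `PromptedOutput` marker:

```python
from pydantic import BaseModel

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.output import PromptedOutput


class Answer(BaseModel):
    value: int


async def fake_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # `model_request_parameters` and `instructions` are the fields added to AgentInfo
    # in this PR. With PromptedOutput, the mode reaching the model is 'prompted' and
    # the instructions should carry the JSON-schema prompt (filled in from the model
    # profile's default template during `prepare_request`).
    print(info.model_request_parameters.output_mode)
    #> prompted
    print(info.instructions)  # the schema prompt built for `Answer`
    return ModelResponse(parts=[TextPart('{"value": 42}')])


agent = Agent(FunctionModel(fake_model), output_type=PromptedOutput(Answer))
result = agent.run_sync('What is the answer?')
print(result.output)
#> value=42
```

These are the same fields the `pydantic_ai/models/function.py` changes further down populate, which makes it easy to assert on the prepared request in tests.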
@@ -1502,7 +1494,7 @@ async def run_mcp_servers( @dataclasses.dataclass(init=False) class _AgentFunctionToolset(FunctionToolset[AgentDepsT]): - output_schema: _output.BaseOutputSchema[Any] + output_schema: _output.OutputSchema[Any] def __init__( self, @@ -1510,7 +1502,7 @@ def __init__( *, max_retries: int = 1, id: str | None = None, - output_schema: _output.BaseOutputSchema[Any], + output_schema: _output.OutputSchema[Any], ): self.output_schema = output_schema super().__init__(tools, max_retries=max_retries, id=id) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 2e5ee010a9..d5aaa5e791 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -485,7 +485,7 @@ class BinaryContent: """ _identifier: Annotated[str | None, pydantic.Field(alias='identifier', default=None, exclude=True)] = field( - compare=False, default=None, repr=False + compare=False, default=None ) kind: Literal['binary'] = 'binary' diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index df7ae9b54e..cc4309c96b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -21,7 +21,7 @@ from .. import _utils from .._json_schema import JsonSchemaTransformer -from .._output import OutputObjectDefinition +from .._output import OutputObjectDefinition, PromptedOutputSchema from .._parts_manager import ModelResponsePartsManager from .._run_context import RunContext from ..builtin_tools import AbstractBuiltinTool @@ -309,6 +309,7 @@ class ModelRequestParameters: output_mode: OutputMode = 'text' output_object: OutputObjectDefinition | None = None output_tools: list[ToolDefinition] = field(default_factory=list) + prompted_output_template: str | None = None allow_text_output: bool = True allow_image_output: bool = False @@ -316,6 +317,12 @@ class ModelRequestParameters: def tool_defs(self) -> dict[str, ToolDefinition]: return {tool_def.name: tool_def for tool_def in [*self.function_tools, *self.output_tools]} + @cached_property + def prompted_output_instructions(self) -> str | None: + if self.output_mode == 'prompted' and self.prompted_output_template and self.output_object: + return PromptedOutputSchema.build_instructions(self.prompted_output_template, self.output_object) + return None + __repr__ = _utils.dataclasses_no_defaults_repr @@ -408,23 +415,52 @@ def prepare_request( ) -> tuple[ModelSettings | None, ModelRequestParameters]: """Prepare request inputs before they are passed to the provider. - This merges the given ``model_settings`` with the model's own ``settings`` attribute and ensures - ``customize_request_parameters`` is applied to the resolved + This merges the given `model_settings` with the model's own `settings` attribute and ensures + `customize_request_parameters` is applied to the resolved [`ModelRequestParameters`][pydantic_ai.models.ModelRequestParameters]. Subclasses can override this method if they need to customize the preparation flow further, but most implementations should simply call - ``self.prepare_request(...)`` at the start of their ``request`` (and related) methods. + `self.prepare_request(...)` at the start of their `request` (and related) methods. 
""" model_settings = merge_model_settings(self.settings, model_settings) - if builtin_tools := model_request_parameters.builtin_tools: + params = self.customize_request_parameters(model_request_parameters) + + if builtin_tools := params.builtin_tools: # Deduplicate builtin tools - model_request_parameters = replace( - model_request_parameters, + params = replace( + params, builtin_tools=list({tool.unique_id: tool for tool in builtin_tools}.values()), ) - model_request_parameters = self.customize_request_parameters(model_request_parameters) - return model_settings, model_request_parameters + if params.output_mode == 'auto': + output_mode = self.profile.default_structured_output_mode + params = replace( + params, + output_mode=output_mode, + allow_text_output=output_mode in ('native', 'prompted'), + ) + + # Reset irrelevant fields + if params.output_tools and params.output_mode != 'tool': + params = replace(params, output_tools=[]) + if params.output_object and params.output_mode not in ('native', 'prompted'): + params = replace(params, output_object=None) + if params.prompted_output_template and params.output_mode != 'prompted': + params = replace(params, prompted_output_template=None) # pragma: no cover + + # Set default prompted output template + if params.output_mode == 'prompted' and not params.prompted_output_template: + params = replace(params, prompted_output_template=self.profile.prompted_output_template) + + # Check if output mode is supported + if params.output_mode == 'native' and not self.profile.supports_json_schema_output: + raise UserError('Native structured output is not supported by this model.') + if params.output_mode == 'tool' and not self.profile.supports_tools: + raise UserError('Tool output is not supported by this model.') + if params.allow_image_output and not self.profile.supports_image_output: + raise UserError('Image output is not supported by this model.') + + return model_settings, params @property @abstractmethod @@ -462,13 +498,17 @@ def base_url(self) -> str | None: return None @staticmethod - def _get_instructions(messages: list[ModelMessage]) -> str | None: + def _get_instructions( + messages: list[ModelMessage], model_request_parameters: ModelRequestParameters | None = None + ) -> str | None: """Get instructions from the first ModelRequest found when iterating messages in reverse. In the case that a "mock" request was generated to include a tool-return part for a result tool, we want to use the instructions from the second-to-most-recent request (which should correspond to the original request that generated the response that resulted in the tool-return part). """ + instructions = None + last_two_requests: list[ModelRequest] = [] for message in reversed(messages): if isinstance(message, ModelRequest): @@ -476,33 +516,38 @@ def _get_instructions(messages: list[ModelMessage]) -> str | None: if len(last_two_requests) == 2: break if message.instructions is not None: - return message.instructions + instructions = message.instructions + break # If we don't have two requests, and we didn't already return instructions, there are definitely not any: - if len(last_two_requests) != 2: - return None - - most_recent_request = last_two_requests[0] - second_most_recent_request = last_two_requests[1] - - # If we've gotten this far and the most recent request consists of only tool-return parts or retry-prompt parts, - # we use the instructions from the second-to-most-recent request. 
This is necessary because when handling - # result tools, we generate a "mock" ModelRequest with a tool-return part for it, and that ModelRequest will not - # have the relevant instructions from the agent. - - # While it's possible that you could have a message history where the most recent request has only tool returns, - # I believe there is no way to achieve that would _change_ the instructions without manually crafting the most - # recent message. That might make sense in principle for some usage pattern, but it's enough of an edge case - # that I think it's not worth worrying about, since you can work around this by inserting another ModelRequest - # with no parts at all immediately before the request that has the tool calls (that works because we only look - # at the two most recent ModelRequests here). - - # If you have a use case where this causes pain, please open a GitHub issue and we can discuss alternatives. - - if all(p.part_kind == 'tool-return' or p.part_kind == 'retry-prompt' for p in most_recent_request.parts): - return second_most_recent_request.instructions - - return None + if instructions is None and len(last_two_requests) == 2: + most_recent_request = last_two_requests[0] + second_most_recent_request = last_two_requests[1] + + # If we've gotten this far and the most recent request consists of only tool-return parts or retry-prompt parts, + # we use the instructions from the second-to-most-recent request. This is necessary because when handling + # result tools, we generate a "mock" ModelRequest with a tool-return part for it, and that ModelRequest will not + # have the relevant instructions from the agent. + + # While it's possible that you could have a message history where the most recent request has only tool returns, + # I believe there is no way to achieve that would _change_ the instructions without manually crafting the most + # recent message. That might make sense in principle for some usage pattern, but it's enough of an edge case + # that I think it's not worth worrying about, since you can work around this by inserting another ModelRequest + # with no parts at all immediately before the request that has the tool calls (that works because we only look + # at the two most recent ModelRequests here). + + # If you have a use case where this causes pain, please open a GitHub issue and we can discuss alternatives. + + if all(p.part_kind == 'tool-return' or p.part_kind == 'retry-prompt' for p in most_recent_request.parts): + instructions = second_most_recent_request.instructions + + if model_request_parameters and (output_instructions := model_request_parameters.prompted_output_instructions): + if instructions: + instructions = '\n\n'.join([instructions, output_instructions]) + else: + instructions = output_instructions + + return instructions @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 80f3bea6e4..31351345b0 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -39,7 +39,7 @@ from ..profiles import ModelProfileSpec from ..providers import Provider, infer_provider from ..providers.anthropic import AsyncAnthropicClient -from ..settings import ModelSettings +from ..settings import ModelSettings, merge_model_settings from ..tools import ToolDefinition from . 
import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent @@ -240,6 +240,27 @@ async def request_stream( async with response: yield await self._process_streamed_response(response, model_request_parameters) + def prepare_request( + self, model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters + ) -> tuple[ModelSettings | None, ModelRequestParameters]: + settings = merge_model_settings(self.settings, model_settings) + if ( + model_request_parameters.output_tools + and settings + and (thinking := settings.get('anthropic_thinking')) + and thinking.get('type') == 'enabled' + ): + if model_request_parameters.output_mode == 'auto': + model_request_parameters = replace(model_request_parameters, output_mode='prompted') + elif ( + model_request_parameters.output_mode == 'tool' and not model_request_parameters.allow_text_output + ): # pragma: no branch + # This would result in `tool_choice=required`, which Anthropic does not support with thinking. + raise UserError( + 'Anthropic does not support thinking and output tools at the same time. Use `output_type=PromptedOutput(...)` instead.' + ) + return super().prepare_request(model_settings, model_request_parameters) + @overload async def _messages_create( self, @@ -278,17 +299,13 @@ async def _messages_create( else: if not model_request_parameters.allow_text_output: tool_choice = {'type': 'any'} - if (thinking := model_settings.get('anthropic_thinking')) and thinking.get('type') == 'enabled': - raise UserError( - 'Anthropic does not support thinking and output tools at the same time. Use `output_type=PromptedOutput(...)` instead.' - ) else: tool_choice = {'type': 'auto'} if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None: tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls - system_prompt, anthropic_messages = await self._map_message(messages) + system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters) try: extra_headers = model_settings.get('extra_headers', {}) @@ -446,7 +463,9 @@ def _add_builtin_tools( ) return tools, mcp_servers, beta_features - async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]: # noqa: C901 + async def _map_message( # noqa: C901 + self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters + ) -> tuple[str, list[BetaMessageParam]]: """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`.""" system_prompt_parts: list[str] = [] anthropic_messages: list[BetaMessageParam] = [] @@ -615,7 +634,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Be anthropic_messages.append(BetaMessageParam(role='assistant', content=assistant_content_params)) else: assert_never(m) - if instructions := self._get_instructions(messages): + if instructions := self._get_instructions(messages, model_request_parameters): system_prompt_parts.insert(0, instructions) system_prompt = '\n\n'.join(system_prompt_parts) return system_prompt, anthropic_messages diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 0584158f1d..7e0bc6a009 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -374,7 +374,7 @@ async def _messages_create( model_settings: BedrockModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> 
ConverseResponseTypeDef | ConverseStreamResponseTypeDef: - system_prompt, bedrock_messages = await self._map_messages(messages) + system_prompt, bedrock_messages = await self._map_messages(messages, model_request_parameters) inference_config = self._map_inference_config(model_settings) params: ConverseRequestTypeDef = { @@ -450,7 +450,7 @@ def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> return tool_config async def _map_messages( # noqa: C901 - self, messages: list[ModelMessage] + self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters ) -> tuple[list[SystemContentBlockTypeDef], list[MessageUnionTypeDef]]: """Maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`. @@ -561,7 +561,7 @@ async def _map_messages( # noqa: C901 processed_messages.append(current_message) last_message = cast(dict[str, Any], current_message) - if instructions := self._get_instructions(messages): + if instructions := self._get_instructions(messages, model_request_parameters): system_prompt.insert(0, {'text': instructions}) return system_prompt, processed_messages diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index b8d7dc9b51..24bb9353c7 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -178,7 +178,7 @@ async def _chat( if model_request_parameters.builtin_tools: raise UserError('Cohere does not support built-in tools') - cohere_messages = self._map_messages(messages) + cohere_messages = self._map_messages(messages, model_request_parameters) try: return await self.client.chat( model=self._model_name, @@ -229,7 +229,9 @@ def _process_response(self, response: V2ChatResponse) -> ModelResponse: provider_details=provider_details, ) - def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]: + def _map_messages( + self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters + ) -> list[ChatMessageV2]: """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`.""" cohere_messages: list[ChatMessageV2] = [] for message in messages: @@ -268,7 +270,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]: cohere_messages.append(message_param) else: assert_never(message) - if instructions := self._get_instructions(messages): + if instructions := self._get_instructions(messages, model_request_parameters): cohere_messages.insert(0, SystemChatMessageV2(role='system', content=instructions)) return cohere_messages diff --git a/pydantic_ai_slim/pydantic_ai/models/fallback.py b/pydantic_ai_slim/pydantic_ai/models/fallback.py index c8430f5775..682ab90ea6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/fallback.py +++ b/pydantic_ai_slim/pydantic_ai/models/fallback.py @@ -3,6 +3,7 @@ from collections.abc import AsyncIterator, Callable from contextlib import AsyncExitStack, asynccontextmanager, suppress from dataclasses import dataclass, field +from functools import cached_property from typing import TYPE_CHECKING, Any from opentelemetry.trace import get_current_span @@ -11,6 +12,7 @@ from pydantic_ai.models.instrumented import InstrumentedModel from ..exceptions import FallbackExceptionGroup, ModelHTTPError +from ..profiles import ModelProfile from . 
import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model if TYPE_CHECKING: @@ -78,6 +80,7 @@ async def request( for model in self.models: try: + _, prepared_parameters = model.prepare_request(model_settings, model_request_parameters) response = await model.request(messages, model_settings, model_request_parameters) except Exception as exc: if self._fallback_on(exc): @@ -85,7 +88,7 @@ async def request( continue raise exc - self._set_span_attributes(model) + self._set_span_attributes(model, prepared_parameters) return response raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions) @@ -104,6 +107,7 @@ async def request_stream( for model in self.models: async with AsyncExitStack() as stack: try: + _, prepared_parameters = model.prepare_request(model_settings, model_request_parameters) response = await stack.enter_async_context( model.request_stream(messages, model_settings, model_request_parameters, run_context) ) @@ -113,19 +117,36 @@ async def request_stream( continue raise exc # pragma: no cover - self._set_span_attributes(model) + self._set_span_attributes(model, prepared_parameters) yield response return raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions) - def _set_span_attributes(self, model: Model): + @cached_property + def profile(self) -> ModelProfile: + raise NotImplementedError('FallbackModel does not have its own model profile.') + + def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters: + return model_request_parameters # pragma: no cover + + def prepare_request( + self, model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters + ) -> tuple[ModelSettings | None, ModelRequestParameters]: + return model_settings, model_request_parameters + + def _set_span_attributes(self, model: Model, model_request_parameters: ModelRequestParameters): with suppress(Exception): span = get_current_span() if span.is_recording(): attributes = getattr(span, 'attributes', {}) if attributes.get('gen_ai.request.model') == self.model_name: # pragma: no branch - span.set_attributes(InstrumentedModel.model_attributes(model)) + span.set_attributes( + { + **InstrumentedModel.model_attributes(model), + **InstrumentedModel.model_request_parameters_attributes(model_request_parameters), + } + ) def _default_fallback_condition_factory(exceptions: tuple[type[Exception], ...]) -> Callable[[Exception], bool]: diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 405c088f7d..37876e3e80 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -135,6 +135,8 @@ async def request( allow_text_output=model_request_parameters.allow_text_output, output_tools=model_request_parameters.output_tools, model_settings=model_settings, + model_request_parameters=model_request_parameters, + instructions=self._get_instructions(messages, model_request_parameters), ) assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests' @@ -168,6 +170,8 @@ async def request_stream( allow_text_output=model_request_parameters.allow_text_output, output_tools=model_request_parameters.output_tools, model_settings=model_settings, + model_request_parameters=model_request_parameters, + instructions=self._get_instructions(messages, model_request_parameters), ) assert self.stream_function is not None, ( @@ -216,6 +220,10 
@@ class AgentInfo: """The tools that can called to produce the final output of the run.""" model_settings: ModelSettings | None """The model settings passed to the run call.""" + model_request_parameters: ModelRequestParameters + """The model request parameters passed to the run call.""" + instructions: str | None + """The instructions passed to model.""" @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 981ef29ef6..afc2bd7156 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -218,7 +218,7 @@ async def _make_request( ) -> AsyncIterator[HTTPResponse]: tools = self._get_tools(model_request_parameters) tool_config = self._get_tool_config(model_request_parameters, tools) - sys_prompt_parts, contents = await self._message_to_gemini_content(messages) + sys_prompt_parts, contents = await self._message_to_gemini_content(messages, model_request_parameters) request_data = _GeminiRequest(contents=contents) if sys_prompt_parts: @@ -331,7 +331,7 @@ async def _process_streamed_response( ) async def _message_to_gemini_content( - self, messages: list[ModelMessage] + self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]: sys_prompt_parts: list[_GeminiTextPart] = [] contents: list[_GeminiContent] = [] @@ -361,7 +361,7 @@ async def _message_to_gemini_content( contents.append(_content_model_response(m)) else: assert_never(m) - if instructions := self._get_instructions(messages): + if instructions := self._get_instructions(messages, model_request_parameters): sys_prompt_parts.insert(0, _GeminiTextPart(text=instructions)) return sys_prompt_parts, contents diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 4978f90e8b..8a32ef1e66 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -3,7 +3,7 @@ import base64 from collections.abc import AsyncIterator, Awaitable from contextlib import asynccontextmanager -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from datetime import datetime from typing import Any, Literal, cast, overload from uuid import uuid4 @@ -224,6 +224,18 @@ def system(self) -> str: """The model provider.""" return self._provider.name + def prepare_request( + self, model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters + ) -> tuple[ModelSettings | None, ModelRequestParameters]: + if model_request_parameters.builtin_tools and model_request_parameters.output_tools: + if model_request_parameters.output_mode == 'auto': + model_request_parameters = replace(model_request_parameters, output_mode='prompted') + else: + raise UserError( + 'Google does not support output tools and built-in tools at the same time. Use `output_type=PromptedOutput(...)` instead.' + ) + return super().prepare_request(model_settings, model_request_parameters) + async def request( self, messages: list[ModelMessage], @@ -320,12 +332,8 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T ] if model_request_parameters.builtin_tools: - if model_request_parameters.output_tools: - raise UserError( - 'Gemini does not support output tools and built-in tools at the same time. Use `output_type=PromptedOutput(...)` instead.' 
- ) if model_request_parameters.function_tools: - raise UserError('Gemini does not support user tools and built-in tools at the same time.') + raise UserError('Google does not support function tools and built-in tools at the same time.') for tool in model_request_parameters.builtin_tools: if isinstance(tool, WebSearchTool): @@ -402,7 +410,7 @@ async def _build_content_and_config( if model_request_parameters.output_mode == 'native': if tools: raise UserError( - 'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.' + 'Google does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.' ) response_mime_type = 'application/json' output_object = model_request_parameters.output_object @@ -414,7 +422,7 @@ async def _build_content_and_config( response_mime_type = 'application/json' tool_config = self._get_tool_config(model_request_parameters, tools) - system_instruction, contents = await self._map_messages(messages) + system_instruction, contents = await self._map_messages(messages, model_request_parameters) modalities = [Modality.TEXT.value] if self.profile.supports_image_output: @@ -504,7 +512,9 @@ async def _process_streamed_response( _provider_name=self._provider.name, ) - async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]: + async def _map_messages( + self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters + ) -> tuple[ContentDict | None, list[ContentUnionDict]]: contents: list[ContentUnionDict] = [] system_parts: list[PartDict] = [] @@ -551,7 +561,7 @@ async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict contents.append(_content_model_response(m, self.system)) else: assert_never(m) - if instructions := self._get_instructions(messages): + if instructions := self._get_instructions(messages, model_request_parameters): system_parts.insert(0, {'text': instructions}) system_instruction = ContentDict(role='user', parts=system_parts) if system_parts else None return system_instruction, contents diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index a310b97a69..67c27a19c2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -272,7 +272,7 @@ async def _completions_create( else: tool_choice = 'auto' - groq_messages = self._map_messages(messages) + groq_messages = self._map_messages(messages, model_request_parameters) response_format: chat.completion_create_params.ResponseFormat | None = None if model_request_parameters.output_mode == 'native': @@ -388,7 +388,9 @@ def _get_builtin_tools( ) return tools - def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]: + def _map_messages( + self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters + ) -> list[chat.ChatCompletionMessageParam]: """Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`.""" groq_messages: list[chat.ChatCompletionMessageParam] = [] for message in messages: @@ -423,7 +425,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletio groq_messages.append(message_param) else: assert_never(message) - if instructions := self._get_instructions(messages): + if instructions := self._get_instructions(messages, model_request_parameters): groq_messages.insert(0, 
chat.ChatCompletionSystemMessageParam(role='system', content=instructions)) return groq_messages diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py index a71edf7026..7ca3199473 100644 --- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -231,7 +231,7 @@ async def _completions_create( if model_request_parameters.builtin_tools: raise UserError('HuggingFace does not support built-in tools') - hf_messages = await self._map_messages(messages) + hf_messages = await self._map_messages(messages, model_request_parameters) try: return await self.client.chat.completions.create( # type: ignore @@ -322,7 +322,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[C return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()] async def _map_messages( - self, messages: list[ModelMessage] + self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters ) -> list[ChatCompletionInputMessage | ChatCompletionOutputMessage]: """Just maps a `pydantic_ai.Message` to a `huggingface_hub.ChatCompletionInputMessage`.""" hf_messages: list[ChatCompletionInputMessage | ChatCompletionOutputMessage] = [] @@ -359,7 +359,7 @@ async def _map_messages( hf_messages.append(message_param) else: assert_never(message) - if instructions := self._get_instructions(messages): + if instructions := self._get_instructions(messages, model_request_parameters): hf_messages.insert(0, ChatCompletionInputMessage(content=instructions, role='system')) # type: ignore return hf_messages diff --git a/pydantic_ai_slim/pydantic_ai/models/instrumented.py b/pydantic_ai_slim/pydantic_ai/models/instrumented.py index 84cd23ba80..c3900896b5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/instrumented.py +++ b/pydantic_ai_slim/pydantic_ai/models/instrumented.py @@ -178,17 +178,20 @@ def __init__( description='Monetary cost', ) - def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]: + def messages_to_otel_events( + self, messages: list[ModelMessage], parameters: ModelRequestParameters | None = None + ) -> list[Event]: """Convert a list of model messages to OpenTelemetry events. Args: messages: The messages to convert. + parameters: The model request parameters. Returns: A list of OpenTelemetry events. 
""" events: list[Event] = [] - instructions = InstrumentedModel._get_instructions(messages) # pyright: ignore [reportPrivateUsage] + instructions = InstrumentedModel._get_instructions(messages, parameters) # pyright: ignore [reportPrivateUsage] if instructions is not None: events.append( Event( @@ -235,10 +238,17 @@ def messages_to_otel_messages(self, messages: list[ModelMessage]) -> list[_otel_ result.append(otel_message) return result - def handle_messages(self, input_messages: list[ModelMessage], response: ModelResponse, system: str, span: Span): + def handle_messages( + self, + input_messages: list[ModelMessage], + response: ModelResponse, + system: str, + span: Span, + parameters: ModelRequestParameters | None = None, + ): if self.version == 1: - events = self.messages_to_otel_events(input_messages) - for event in self.messages_to_otel_events([response]): + events = self.messages_to_otel_events(input_messages, parameters) + for event in self.messages_to_otel_events([response], parameters): events.append( Event( 'gen_ai.choice', @@ -258,7 +268,7 @@ def handle_messages(self, input_messages: list[ModelMessage], response: ModelRes output_messages = self.messages_to_otel_messages([response]) assert len(output_messages) == 1 output_message = output_messages[0] - instructions = InstrumentedModel._get_instructions(input_messages) # pyright: ignore [reportPrivateUsage] + instructions = InstrumentedModel._get_instructions(input_messages, parameters) # pyright: ignore [reportPrivateUsage] system_instructions_attributes = self.system_instructions_attributes(instructions) attributes: dict[str, AttributeValue] = { 'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)), @@ -360,7 +370,7 @@ async def request( ) with self._instrument(messages, prepared_settings, prepared_parameters) as finish: response = await self.wrapped.request(messages, model_settings, model_request_parameters) - finish(response) + finish(response, prepared_parameters) return response @asynccontextmanager @@ -384,7 +394,7 @@ async def request_stream( yield response_stream finally: if response_stream: # pragma: no branch - finish(response_stream.get()) + finish(response_stream.get(), prepared_parameters) @contextmanager def _instrument( @@ -392,7 +402,7 @@ def _instrument( messages: list[ModelMessage], model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, - ) -> Iterator[Callable[[ModelResponse], None]]: + ) -> Iterator[Callable[[ModelResponse, ModelRequestParameters], None]]: operation = 'chat' span_name = f'{operation} {self.model_name}' # TODO Missing attributes: @@ -401,7 +411,7 @@ def _instrument( attributes: dict[str, AttributeValue] = { 'gen_ai.operation.name': operation, **self.model_attributes(self.wrapped), - 'model_request_parameters': json.dumps(InstrumentedModel.serialize_any(model_request_parameters)), + **self.model_request_parameters_attributes(model_request_parameters), 'logfire.json_schema': json.dumps( { 'type': 'object', @@ -419,7 +429,7 @@ def _instrument( try: with self.instrumentation_settings.tracer.start_as_current_span(span_name, attributes=attributes) as span: - def finish(response: ModelResponse): + def finish(response: ModelResponse, parameters: ModelRequestParameters): # FallbackModel updates these span attributes. 
attributes.update(getattr(span, 'attributes', {})) request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE] @@ -443,7 +453,7 @@ def _record_metrics(): if not span.is_recording(): return - self.instrumentation_settings.handle_messages(messages, response, system, span) + self.instrumentation_settings.handle_messages(messages, response, system, span, parameters) attributes_to_set = { **response.usage.opentelemetry_attributes(), @@ -476,7 +486,7 @@ def _record_metrics(): record_metrics() @staticmethod - def model_attributes(model: Model): + def model_attributes(model: Model) -> dict[str, AttributeValue]: attributes: dict[str, AttributeValue] = { GEN_AI_SYSTEM_ATTRIBUTE: model.system, GEN_AI_REQUEST_MODEL_ATTRIBUTE: model.model_name, @@ -494,6 +504,12 @@ def model_attributes(model: Model): return attributes + @staticmethod + def model_request_parameters_attributes( + model_request_parameters: ModelRequestParameters, + ) -> dict[str, AttributeValue]: + return {'model_request_parameters': json.dumps(InstrumentedModel.serialize_any(model_request_parameters))} + @staticmethod def event_to_dict(event: Event) -> dict[str, Any]: if not event.body: diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 90265bbe53..770c8ff6ca 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -230,7 +230,7 @@ async def _completions_create( try: response = await self.client.chat.complete_async( model=str(self._model_name), - messages=self._map_messages(messages), + messages=self._map_messages(messages, model_request_parameters), n=1, tools=self._map_function_and_output_tools_definition(model_request_parameters) or UNSET, tool_choice=self._get_tool_choice(model_request_parameters), @@ -259,7 +259,7 @@ async def _stream_completions_create( ) -> MistralEventStreamAsync[MistralCompletionEvent]: """Create a streaming completion request to the Mistral model.""" response: MistralEventStreamAsync[MistralCompletionEvent] | None - mistral_messages = self._map_messages(messages) + mistral_messages = self._map_messages(messages, model_request_parameters) # TODO(Marcelo): We need to replace the current MistralAI client to use the beta client. # See https://docs.mistral.ai/agents/connectors/websearch/ to support web search. 
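Away from the per-provider plumbing, the practical upshot of the Google `prepare_request` override and the reworded `builtin-tools.md` notes earlier in this diff is that structured output next to Google's built-in tools goes through `PromptedOutput`. A rough usage sketch, with the `google-gla:gemini-2.5-flash` model name and provider credentials assumed rather than taken from the PR:

```python
from pydantic import BaseModel

from pydantic_ai import Agent, WebSearchTool
from pydantic_ai.output import PromptedOutput


class OlympicsInfo(BaseModel):
    city: str
    country: str


agent = Agent(
    'google-gla:gemini-2.5-flash',
    builtin_tools=[WebSearchTool()],
    # Google doesn't support output tools alongside built-in tools; prompted
    # (JSON-in-text) structured output avoids the tool-based output path entirely.
    # Per the new `prepare_request`, a plain `output_type=OlympicsInfo` ('auto' mode)
    # would be switched to prompted automatically, while explicit tool output raises
    # a UserError.
    output_type=PromptedOutput(OlympicsInfo),
)

result = agent.run_sync('Which city hosted the 2024 Summer Olympics?')
print(result.output)  # e.g. city='Paris' country='France', depending on the search results
```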
@@ -523,7 +523,9 @@ def _map_user_message(self, message: ModelRequest) -> Iterable[MistralMessages]: else: assert_never(part) - def _map_messages(self, messages: list[ModelMessage]) -> list[MistralMessages]: + def _map_messages( + self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters + ) -> list[MistralMessages]: """Just maps a `pydantic_ai.Message` to a `MistralMessage`.""" mistral_messages: list[MistralMessages] = [] for message in messages: @@ -554,7 +556,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[MistralMessages]: mistral_messages.append(MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls)) else: assert_never(message) - if instructions := self._get_instructions(messages): + if instructions := self._get_instructions(messages, model_request_parameters): mistral_messages.insert(0, MistralSystemMessage(content=instructions)) # Post-process messages to insert fake assistant message after tool message if followed by user message diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index c5a57d5b05..ed1e711823 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -477,7 +477,7 @@ async def _completions_create( else: tool_choice = 'auto' - openai_messages = await self._map_messages(messages) + openai_messages = await self._map_messages(messages, model_request_parameters) response_format: chat.completion_create_params.ResponseFormat | None = None if model_request_parameters.output_mode == 'native': @@ -672,7 +672,9 @@ def _get_web_search_options(self, model_request_parameters: ModelRequestParamete f'`{tool.__class__.__name__}` is not supported by `OpenAIChatModel`. If it should be, please file an issue.' 
) - async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]: + async def _map_messages( + self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters + ) -> list[chat.ChatCompletionMessageParam]: """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`.""" openai_messages: list[chat.ChatCompletionMessageParam] = [] for message in messages: @@ -713,7 +715,7 @@ async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCom openai_messages.append(message_param) else: assert_never(message) - if instructions := self._get_instructions(messages): + if instructions := self._get_instructions(messages, model_request_parameters): openai_messages.insert(0, chat.ChatCompletionSystemMessageParam(content=instructions, role='system')) return openai_messages @@ -1164,7 +1166,7 @@ async def _responses_create( if previous_response_id == 'auto': previous_response_id, messages = self._get_previous_response_id_and_new_messages(messages) - instructions, openai_messages = await self._map_messages(messages, model_settings) + instructions, openai_messages = await self._map_messages(messages, model_settings, model_request_parameters) reasoning = self._get_reasoning(model_settings) text: responses.ResponseTextConfigParam | None = None @@ -1352,7 +1354,10 @@ def _get_previous_response_id_and_new_messages( return None, messages async def _map_messages( # noqa: C901 - self, messages: list[ModelMessage], model_settings: OpenAIResponsesModelSettings + self, + messages: list[ModelMessage], + model_settings: OpenAIResponsesModelSettings, + model_request_parameters: ModelRequestParameters, ) -> tuple[str | NotGiven, list[responses.ResponseInputItemParam]]: """Just maps a `pydantic_ai.Message` to a `openai.types.responses.ResponseInputParam`.""" profile = OpenAIModelProfile.from_profile(self.profile) @@ -1577,7 +1582,7 @@ async def _map_messages( # noqa: C901 assert_never(item) else: assert_never(message) - instructions = self._get_instructions(messages) or NOT_GIVEN + instructions = self._get_instructions(messages, model_request_parameters) or NOT_GIVEN return instructions, openai_messages def _map_json_schema(self, o: OutputObjectDefinition) -> responses.ResponseFormatTextJSONSchemaConfigParam: diff --git a/pydantic_ai_slim/pydantic_ai/models/outlines.py b/pydantic_ai_slim/pydantic_ai/models/outlines.py index 69d2aecd2b..5b439952c1 100644 --- a/pydantic_ai_slim/pydantic_ai/models/outlines.py +++ b/pydantic_ai_slim/pydantic_ai/models/outlines.py @@ -8,14 +8,13 @@ import io from collections.abc import AsyncIterable, AsyncIterator, Sequence from contextlib import asynccontextmanager -from dataclasses import dataclass +from dataclasses import dataclass, replace from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Literal, cast from typing_extensions import assert_never from .. 
import UnexpectedModelBehavior, _utils -from .._output import PromptedOutputSchema from .._run_context import RunContext from .._thinking_part import split_content_into_text_and_thinking from ..exceptions import UserError @@ -247,6 +246,10 @@ async def request( model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters, ) -> ModelResponse: + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) """Make a request to the model.""" prompt, output_type, inference_kwargs = await self._build_generation_arguments( messages, model_settings, model_request_parameters @@ -267,6 +270,11 @@ async def request_stream( model_request_parameters: ModelRequestParameters, run_context: RunContext[Any] | None = None, ) -> AsyncIterator[StreamedResponse]: + model_settings, model_request_parameters = self.prepare_request( + model_settings, + model_request_parameters, + ) + prompt, output_type, inference_kwargs = await self._build_generation_arguments( messages, model_settings, model_request_parameters ) @@ -298,15 +306,11 @@ async def _build_generation_arguments( raise UserError('Outlines does not support function tools and builtin tools yet.') if model_request_parameters.output_object: - instructions = PromptedOutputSchema.build_instructions( - self.profile.prompted_output_template, model_request_parameters.output_object - ) output_type = JsonSchema(model_request_parameters.output_object.json_schema) else: - instructions = None output_type = None - prompt = await self._format_prompt(messages, instructions) + prompt = await self._format_prompt(messages, model_request_parameters) inference_kwargs = self.format_inference_kwargs(model_settings) return prompt, output_type, inference_kwargs @@ -416,17 +420,14 @@ def _format_vllm_offline_inference_kwargs( # pragma: no cover return filtered_settings async def _format_prompt( # noqa: C901 - self, messages: list[ModelMessage], output_format_instructions: str | None + self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters ) -> Chat: """Turn the model messages into an Outlines Chat instance.""" chat = Chat() - if instructions := self._get_instructions(messages): + if instructions := self._get_instructions(messages, model_request_parameters): chat.add_system_message(instructions) - if output_format_instructions: - chat.add_system_message(output_format_instructions) - for message in messages: if isinstance(message, ModelRequest): for part in message.parts: @@ -525,6 +526,14 @@ async def _process_streamed_response( _provider_name='outlines', ) + def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters: + """Customize the model request parameters for the model.""" + if model_request_parameters.output_mode in ('auto', 'native'): + # This way the JSON schema will be included in the instructions. 
+ return replace(model_request_parameters, output_mode='prompted') + else: + return model_request_parameters + @dataclass class OutlinesStreamedResponse(StreamedResponse): diff --git a/pydantic_ai_slim/pydantic_ai/models/wrapper.py b/pydantic_ai_slim/pydantic_ai/models/wrapper.py index 3260cc7d65..5d725eca95 100644 --- a/pydantic_ai_slim/pydantic_ai/models/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/models/wrapper.py @@ -44,7 +44,7 @@ async def request_stream( yield response_stream def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters: - return self.wrapped.customize_request_parameters(model_request_parameters) + return self.wrapped.customize_request_parameters(model_request_parameters) # pragma: no cover def prepare_request( self, diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index 27d7f84aea..cd5e5865a6 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -37,10 +37,11 @@ OutputDataT = TypeVar('OutputDataT', default=str, covariant=True) """Covariant type variable for the output data type of a run.""" -OutputMode = Literal['text', 'tool', 'native', 'prompted', 'tool_or_text', 'image'] +OutputMode = Literal['text', 'tool', 'native', 'prompted', 'tool_or_text', 'image', 'auto'] """All output modes. -`tool_or_text` is deprecated and no longer in use. +- `tool_or_text` is deprecated and no longer in use. +- `auto` means the model will automatically choose a structured output mode based on the model's `ModelProfile.default_structured_output_mode`. """ StructuredOutputMode = Literal['tool', 'native', 'prompted'] """Output modes that can be used for structured output. Used by ModelProfile.default_structured_output_mode""" diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_output_tool_with_thinking.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_output_tool_with_thinking.yaml index 7b04ea09fe..0ea57131be 100644 --- a/tests/models/cassettes/test_anthropic/test_anthropic_output_tool_with_thinking.yaml +++ b/tests/models/cassettes/test_anthropic/test_anthropic_output_tool_with_thinking.yaml @@ -8,7 +8,7 @@ interactions: connection: - keep-alive content-length: - - '471' + - '475' content-type: - application/json host: @@ -23,7 +23,8 @@ interactions: role: user model: claude-sonnet-4-0 stream: false - system: |- + system: |2 + Always respond with a JSON object that's compatible with this schema: {"properties": {"response": {"type": "integer"}}, "required": ["response"], "type": "object", "title": "int"} @@ -38,22 +39,26 @@ interactions: connection: - keep-alive content-length: - - '1032' + - '1150' content-type: - application/json + retry-after: + - '54' strict-transport-security: - max-age=31536000; includeSubDomains; preload transfer-encoding: - chunked parsed_body: content: - - signature: ErQCCkYICBgCKkDId3yuTWB+RmnrHX1N/m+Q+uvt6TTyU6tRWGzFYK1UmQo+lkK5PFjgRvLK6eXA/q8sbVIC6mO3/1eq5aTkSX7+EgzYnRzfZdWJZ1X+410aDC5zyOhlAbOmBiBUmiIwPC3/mI3lVg3woo5Q2jwuZ/u+Pl8LMzrFxG0YbK2F5YDVuCjhrJsOq5e1V36GWJjqKpsBsKiPfPZQ6wizN25g64pwJb+Wjm55hDeGpK8xJeVFuren6PNKKkruBtlK1PIVpjSXBGkdTJCC69xlhwaXF20zah/A8HDm/2QEqid8Gz8+7zu+b7OGa22WdW0uZEwQgJtydTscZFqWzyAm8CZtsCh8STbRNPggOaCNg9vX5ipu2D+jXnAaL6MIOOQ3FUO+CdljS0mvfoyeCabS8TUYAQ== - thinking: The user is asking for 3 + 3, which equals 6. I need to respond with a JSON object that has a "response" - field containing an integer value. 
+ - signature: EuYCCkYICRgCKkBKb+DTJGUMVOQahj61RknYW0QUDawJfq0T0GPDVPY12LCBbS7YPklMKo29mW3gdTAfPBWgYGmOj51p1jkFst2/Egw0xpDI3vnsrcqx484aDB8G93CLqlAq112quyIwq1/wOAOxPiIRklQ/i2iN/UzmWwPrGHmSS+TAq7qh2VQdi32TUk2zVXlOmTJdOSquKs0BbVTmLPWPc7szqedimy5uTbErLKLALr6DH1RRXuvGeRNElsnJofVsDu48aqeZg36g3Pi9Hboj1oE/TpyclCbv9/CWrixeQ/L/GSggr3FxLJvDgpdtppZfRxWajS6DjTH0AOU2aEu1gvxGtrcIa8htRmo5ZwAxISkaiOAm1lY5pSMl31gRFwby3n/2Y32b3UbM4SSlidDCgOTrDtbJSuwygduhfu7OdPg/I737G+sLcB0RUq4rqnPQQ+T+NYuDHPOz5xyGooXi7UNygIrO2BgB + thinking: |- + The user is asking me to calculate 3 + 3, which equals 6. They want me to respond with a JSON object that has a "response" field with an integer value. So I need to return: + + {"response": 6} type: thinking - text: '{"response": 6}' type: text - id: msg_01Fo4JKsQzJMTQBgLSAeCJDG + id: msg_013vH3ViFyo8f85fA4HJuF8A model: claude-sonnet-4-20250514 role: assistant stop_reason: end_turn @@ -65,8 +70,8 @@ interactions: ephemeral_5m_input_tokens: 0 cache_creation_input_tokens: 0 cache_read_input_tokens: 0 - input_tokens: 105 - output_tokens: 55 + input_tokens: 106 + output_tokens: 71 service_tier: standard status: code: 200 diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index e7107c1824..130206e3aa 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -48,6 +48,7 @@ BuiltinToolCallEvent, # pyright: ignore[reportDeprecated] BuiltinToolResultEvent, # pyright: ignore[reportDeprecated] ) +from pydantic_ai.models import ModelRequestParameters from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput from pydantic_ai.result import RunUsage from pydantic_ai.settings import ModelSettings @@ -4629,14 +4630,14 @@ async def test_anthropic_empty_content_filtering(env: TestEnv): messages_empty_string: list[ModelMessage] = [ ModelRequest(parts=[UserPromptPart(content='')], kind='request'), ] - _, anthropic_messages = await model._map_message(messages_empty_string) # type: ignore[attr-defined] + _, anthropic_messages = await model._map_message(messages_empty_string, ModelRequestParameters()) # type: ignore[attr-defined] assert anthropic_messages == snapshot([]) # Empty content should be filtered out # Test _map_message with list containing empty strings in user prompt messages_mixed_content: list[ModelMessage] = [ ModelRequest(parts=[UserPromptPart(content=['', 'Hello', '', 'World'])], kind='request'), ] - _, anthropic_messages = await model._map_message(messages_mixed_content) # type: ignore[attr-defined] + _, anthropic_messages = await model._map_message(messages_mixed_content, ModelRequestParameters()) # type: ignore[attr-defined] assert anthropic_messages == snapshot( [{'role': 'user', 'content': [{'text': 'Hello', 'type': 'text'}, {'text': 'World', 'type': 'text'}]}] ) @@ -4647,7 +4648,7 @@ async def test_anthropic_empty_content_filtering(env: TestEnv): ModelResponse(parts=[TextPart(content='')], kind='response'), # Empty response ModelRequest(parts=[UserPromptPart(content='Hello')], kind='request'), ] - _, anthropic_messages = await model._map_message(messages) # type: ignore[attr-defined] + _, anthropic_messages = await model._map_message(messages, ModelRequestParameters()) # type: ignore[attr-defined] # The empty assistant message should be filtered out assert anthropic_messages == snapshot([{'role': 'user', 'content': [{'text': 'Hello', 'type': 'text'}]}]) @@ -4655,7 +4656,7 @@ async def test_anthropic_empty_content_filtering(env: TestEnv): messages_resp: list[ModelMessage] = [ 
ModelResponse(parts=[TextPart(content=''), TextPart(content='')], kind='response'), ] - _, anthropic_messages = await model._map_message(messages_resp) # type: ignore[attr-defined] + _, anthropic_messages = await model._map_message(messages_resp, ModelRequestParameters()) # type: ignore[attr-defined] assert len(anthropic_messages) == 0 # No messages should be added @@ -4871,14 +4872,7 @@ async def get_user_country() -> str: content='What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ @@ -4909,14 +4903,7 @@ async def get_user_country() -> str: tool_call_id='toolu_01ArHq5f2wxRpRF2PVQcKExM', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], @@ -4965,14 +4952,7 @@ class CountryLanguage(BaseModel): content='What is the largest city in Mexico?', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ @@ -5021,7 +5001,7 @@ async def test_anthropic_output_tool_with_thinking(allow_model_requests: None, a settings=AnthropicModelSettings(anthropic_thinking={'type': 'enabled', 'budget_tokens': 3000}), ) - agent = Agent(m, output_type=int) + agent = Agent(m, output_type=ToolOutput(int)) with pytest.raises( UserError, @@ -5031,7 +5011,8 @@ async def test_anthropic_output_tool_with_thinking(allow_model_requests: None, a ): await agent.run('What is 3 + 3?') - agent = Agent(m, output_type=PromptedOutput(int)) + # Will default to prompted output + agent = Agent(m, output_type=int) result = await agent.run('What is 3 + 3?') assert result.output == snapshot(6) diff --git a/tests/models/test_bedrock.py b/tests/models/test_bedrock.py index 7915f4f680..3246c495e0 100644 --- a/tests/models/test_bedrock.py +++ b/tests/models/test_bedrock.py @@ -1120,7 +1120,7 @@ async def test_bedrock_group_consecutive_tool_return_parts(bedrock_provider: Bed ] # Call the mapping function directly - _, 
bedrock_messages = await model._map_messages(req) # type: ignore[reportPrivateUsage] + _, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # type: ignore[reportPrivateUsage] assert bedrock_messages == snapshot( [ @@ -1239,7 +1239,7 @@ async def test_bedrock_mistral_tool_result_format(bedrock_provider: BedrockProvi # Models other than Mistral support toolResult.content with text, not json model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) # Call the mapping function directly - _, bedrock_messages = await model._map_messages(req) # type: ignore[reportPrivateUsage] + _, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # type: ignore[reportPrivateUsage] assert bedrock_messages == snapshot( [ @@ -1255,7 +1255,7 @@ async def test_bedrock_mistral_tool_result_format(bedrock_provider: BedrockProvi # Mistral requires toolResult.content to hold json, not text model = BedrockConverseModel('mistral.mistral-7b-instruct-v0:2', provider=bedrock_provider) # Call the mapping function directly - _, bedrock_messages = await model._map_messages(req) # type: ignore[reportPrivateUsage] + _, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # type: ignore[reportPrivateUsage] assert bedrock_messages == snapshot( [ diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py index 7ad2a34a6e..d7ee01f481 100644 --- a/tests/models/test_fallback.py +++ b/tests/models/test_fallback.py @@ -4,17 +4,31 @@ import sys from collections.abc import AsyncIterator from datetime import timezone -from typing import Any, cast +from typing import Any, Literal, cast import pytest from _pytest.python_api import RaisesContext from dirty_equals import IsJson from inline_snapshot import snapshot +from pydantic import BaseModel from pydantic_core import to_json -from pydantic_ai import Agent, ModelHTTPError, ModelMessage, ModelRequest, ModelResponse, TextPart, UserPromptPart +from pydantic_ai import ( + Agent, + ModelHTTPError, + ModelMessage, + ModelProfile, + ModelRequest, + ModelResponse, + TextPart, + ToolCallPart, + ToolDefinition, + UserPromptPart, +) +from pydantic_ai.models import ModelRequestParameters from pydantic_ai.models.fallback import FallbackModel from pydantic_ai.models.function import AgentInfo, FunctionModel +from pydantic_ai.output import OutputObjectDefinition from pydantic_ai.settings import ModelSettings from pydantic_ai.usage import RequestUsage @@ -138,6 +152,7 @@ def test_first_failed_instrumented(capfire: CaptureLogfire) -> None: 'output_mode': 'text', 'output_object': None, 'output_tools': [], + 'prompted_output_template': None, 'allow_text_output': True, 'allow_image_output': False, }, @@ -245,6 +260,7 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None 'output_mode': 'text', 'output_object': None, 'output_tools': [], + 'prompted_output_template': None, 'allow_text_output': True, 'allow_image_output': False, }, @@ -354,6 +370,7 @@ def test_all_failed_instrumented(capfire: CaptureLogfire) -> None: 'output_mode': 'text', 'output_object': None, 'output_tools': [], + 'prompted_output_template': None, 'allow_text_output': True, 'allow_image_output': False, }, @@ -606,3 +623,275 @@ async def return_settings_stream(_: list[ModelMessage], info: AgentInfo): expected = {'extra_headers': {'anthropic-beta': 'context-1m-2025-08-07'}, 'temperature': 0.5} assert json.loads(output) == expected + + +async def test_fallback_model_structured_output(): + class 
Foo(BaseModel): + bar: str + + def tool_output_func(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal enabled_model + if enabled_model != 'tool': + raise ModelHTTPError(status_code=500, model_name='tool-model', body=None) + + assert info.model_request_parameters == snapshot( + ModelRequestParameters( + output_mode='tool', + output_tools=[ + ToolDefinition( + name='final_result', + parameters_json_schema={ + 'properties': {'bar': {'type': 'string'}}, + 'required': ['bar'], + 'title': 'Foo', + 'type': 'object', + }, + description='The final response which ends this conversation', + kind='output', + ) + ], + allow_text_output=False, + ) + ) + + args = Foo(bar='baz').model_dump() + assert info.output_tools + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args)]) + + def native_output_func(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal enabled_model + if enabled_model != 'native': + raise ModelHTTPError(status_code=500, model_name='native-model', body=None) + + assert info.model_request_parameters == snapshot( + ModelRequestParameters( + output_mode='native', + output_object=OutputObjectDefinition( + json_schema={ + 'properties': {'bar': {'type': 'string'}}, + 'required': ['bar'], + 'title': 'Foo', + 'type': 'object', + }, + name='Foo', + ), + ) + ) + + text = Foo(bar='baz').model_dump_json() + return ModelResponse(parts=[TextPart(content=text)]) + + def prompted_output_func(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal enabled_model + if enabled_model != 'prompted': + raise ModelHTTPError(status_code=500, model_name='prompted-model', body=None) # pragma: no cover + + assert info.model_request_parameters == snapshot( + ModelRequestParameters( + output_mode='prompted', + output_object=OutputObjectDefinition( + json_schema={ + 'properties': {'bar': {'type': 'string'}}, + 'required': ['bar'], + 'title': 'Foo', + 'type': 'object', + }, + name='Foo', + ), + prompted_output_template="""\ + +Always respond with a JSON object that's compatible with this schema: + +{schema} + +Don't include any text or Markdown fencing before or after. 
+""", + ) + ) + + text = Foo(bar='baz').model_dump_json() + return ModelResponse(parts=[TextPart(content=text)]) + + tool_model = FunctionModel( + tool_output_func, profile=ModelProfile(default_structured_output_mode='tool', supports_tools=True) + ) + native_model = FunctionModel( + native_output_func, + profile=ModelProfile(default_structured_output_mode='native', supports_json_schema_output=True), + ) + prompted_model = FunctionModel( + prompted_output_func, profile=ModelProfile(default_structured_output_mode='prompted') + ) + + fallback_model = FallbackModel(tool_model, native_model, prompted_model) + agent = Agent(fallback_model, output_type=Foo) + + enabled_model: Literal['tool', 'native', 'prompted'] = 'tool' + tool_result = await agent.run('hello') + assert tool_result.output == snapshot(Foo(bar='baz')) + + enabled_model = 'native' + tool_result = await agent.run('hello') + assert tool_result.output == snapshot(Foo(bar='baz')) + + enabled_model = 'prompted' + tool_result = await agent.run('hello') + assert tool_result.output == snapshot(Foo(bar='baz')) + + +@pytest.mark.skipif(not logfire_imports_successful(), reason='logfire not installed') +async def test_fallback_model_structured_output_instrumented(capfire: CaptureLogfire) -> None: + class Foo(BaseModel): + bar: str + + def tool_output_func(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + raise ModelHTTPError(status_code=500, model_name='tool-model', body=None) + + def prompted_output_func(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.model_request_parameters == snapshot( + ModelRequestParameters( + output_mode='prompted', + output_object=OutputObjectDefinition( + json_schema={ + 'properties': {'bar': {'type': 'string'}}, + 'required': ['bar'], + 'title': 'Foo', + 'type': 'object', + }, + name='Foo', + ), + prompted_output_template="""\ + +Always respond with a JSON object that's compatible with this schema: + +{schema} + +Don't include any text or Markdown fencing before or after. 
+""", + ) + ) + + text = Foo(bar='baz').model_dump_json() + return ModelResponse(parts=[TextPart(content=text)]) + + tool_model = FunctionModel( + tool_output_func, profile=ModelProfile(default_structured_output_mode='tool', supports_tools=True) + ) + prompted_model = FunctionModel( + prompted_output_func, profile=ModelProfile(default_structured_output_mode='prompted') + ) + fallback_model = FallbackModel(tool_model, prompted_model) + agent = Agent(model=fallback_model, instrument=True, output_type=Foo, instructions='Be kind') + result = await agent.run('hello') + assert result.output == snapshot(Foo(bar='baz')) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='hello', + timestamp=IsNow(tz=timezone.utc), + ) + ], + instructions='Be kind', + ), + ModelResponse( + parts=[TextPart(content='{"bar":"baz"}')], + usage=RequestUsage(input_tokens=51, output_tokens=4), + model_name='function:prompted_output_func:', + timestamp=IsNow(tz=timezone.utc), + ), + ] + ) + assert capfire.exporter.exported_spans_as_dict(parse_json_attributes=True) == snapshot( + [ + { + 'name': 'chat function:prompted_output_func:', + 'context': {'trace_id': 1, 'span_id': 3, 'is_remote': False}, + 'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'start_time': 2000000000, + 'end_time': 3000000000, + 'attributes': { + 'gen_ai.operation.name': 'chat', + 'model_request_parameters': { + 'function_tools': [], + 'builtin_tools': [], + 'output_mode': 'prompted', + 'output_object': { + 'json_schema': { + 'properties': {'bar': {'type': 'string'}}, + 'required': ['bar'], + 'title': 'Foo', + 'type': 'object', + }, + 'name': 'Foo', + 'description': None, + 'strict': None, + }, + 'output_tools': [], + 'prompted_output_template': """\ + +Always respond with a JSON object that's compatible with this schema: + +{schema} + +Don't include any text or Markdown fencing before or after. 
+""", + 'allow_text_output': True, + 'allow_image_output': False, + }, + 'logfire.span_type': 'span', + 'logfire.msg': 'chat fallback:function:tool_output_func:,function:prompted_output_func:', + 'gen_ai.system': 'function', + 'gen_ai.request.model': 'function:prompted_output_func:', + 'gen_ai.input.messages': [{'role': 'user', 'parts': [{'type': 'text', 'content': 'hello'}]}], + 'gen_ai.output.messages': [ + {'role': 'assistant', 'parts': [{'type': 'text', 'content': '{"bar":"baz"}'}]} + ], + 'gen_ai.system_instructions': [{'type': 'text', 'content': 'Be kind'}], + 'gen_ai.usage.input_tokens': 51, + 'gen_ai.usage.output_tokens': 4, + 'gen_ai.response.model': 'function:prompted_output_func:', + 'logfire.json_schema': { + 'type': 'object', + 'properties': { + 'gen_ai.input.messages': {'type': 'array'}, + 'gen_ai.output.messages': {'type': 'array'}, + 'gen_ai.system_instructions': {'type': 'array'}, + 'model_request_parameters': {'type': 'object'}, + }, + }, + }, + }, + { + 'name': 'agent run', + 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'parent': None, + 'start_time': 1000000000, + 'end_time': 4000000000, + 'attributes': { + 'model_name': 'fallback:function:tool_output_func:,function:prompted_output_func:', + 'agent_name': 'agent', + 'gen_ai.agent.name': 'agent', + 'logfire.msg': 'agent run', + 'logfire.span_type': 'span', + 'gen_ai.usage.input_tokens': 51, + 'gen_ai.usage.output_tokens': 4, + 'pydantic_ai.all_messages': [ + {'role': 'user', 'parts': [{'type': 'text', 'content': 'hello'}]}, + {'role': 'assistant', 'parts': [{'type': 'text', 'content': '{"bar":"baz"}'}]}, + ], + 'final_result': {'bar': 'baz'}, + 'gen_ai.system_instructions': [{'type': 'text', 'content': 'Be kind'}], + 'logfire.json_schema': { + 'type': 'object', + 'properties': { + 'pydantic_ai.all_messages': {'type': 'array'}, + 'gen_ai.system_instructions': {'type': 'array'}, + 'final_result': {'type': 'object'}, + }, + }, + }, + }, + ] + ) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 37e171ccad..d24189c90b 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -2002,14 +2002,7 @@ class CityLocation(BaseModel): content='What is the largest city in Mexico?', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ @@ -2056,14 +2049,7 @@ async def get_user_country() -> str: content='What is the largest city in the user country? 
Use the get_user_country tool and then your own world knowledge.', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], @@ -2083,14 +2069,7 @@ async def get_user_country() -> str: tool_call_id=IsStr(), timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], @@ -2131,14 +2110,7 @@ class CountryLanguage(BaseModel): content='What is the largest city in Mexico?', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 6ca78d4484..d560b58c77 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -2323,7 +2323,7 @@ async def get_user_country() -> str: with pytest.raises( UserError, match=re.escape( - 'Gemini does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.' + 'Google does not support `NativeOutput` and tools at the same time. Use `output_type=ToolOutput(...)` instead.' ), ): await agent.run('What is the largest city in the user country?') @@ -2454,14 +2454,7 @@ class CityLocation(BaseModel): content='What is the largest city in Mexico?', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], @@ -2505,14 +2498,7 @@ async def get_user_country() -> str: content='What is the largest city in the user country? 
Use the get_user_country tool and then your own world knowledge.', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], @@ -2534,14 +2520,7 @@ async def get_user_country() -> str: tool_call_id=IsStr(), timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], @@ -2583,14 +2562,7 @@ class CountryLanguage(BaseModel): content='What is the largest city in Mexico?', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ @@ -2713,7 +2685,7 @@ async def get_user_country() -> str: with pytest.raises( UserError, - match=re.escape('Gemini does not support user tools and built-in tools at the same time.'), + match=re.escape('Google does not support function tools and built-in tools at the same time.'), ): await agent.run('What is the largest city in the user country?') @@ -2726,12 +2698,14 @@ class CityLocation(BaseModel): with pytest.raises( UserError, match=re.escape( - 'Gemini does not support output tools and built-in tools at the same time. Use `output_type=PromptedOutput(...)` instead.' + 'Google does not support output tools and built-in tools at the same time. Use `output_type=PromptedOutput(...)` instead.' 
), ): await agent.run('What is the largest city in Mexico?') - agent = Agent(m, output_type=PromptedOutput(CityLocation), builtin_tools=[UrlContextTool()]) + # Will default to prompted output + agent = Agent(m, output_type=CityLocation, builtin_tools=[UrlContextTool()]) + result = await agent.run('What is the largest city in Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) @@ -2844,7 +2818,6 @@ async def test_google_image_generation_stream(allow_model_requests: None, google BinaryImage( data=IsBytes(), media_type='image/png', - _identifier='9ff9cc', identifier='9ff9cc', ) ) @@ -2952,7 +2925,6 @@ async def test_google_image_generation_with_text(allow_model_requests: None, goo content=BinaryImage( data=IsBytes(), media_type='image/png', - _identifier='00f2af', identifier=IsStr(), ) ), @@ -2988,7 +2960,6 @@ async def test_google_image_or_text_output(allow_model_requests: None, google_pr BinaryImage( data=IsBytes(), media_type='image/png', - _identifier='f82faf', identifier='f82faf', ) ) @@ -3007,7 +2978,6 @@ async def test_google_image_and_text_output(allow_model_requests: None, google_p BinaryImage( data=IsBytes(), media_type='image/png', - _identifier='67b12f', identifier='67b12f', ) ] diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 5ce53b251c..928ebd8907 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -5528,14 +5528,7 @@ class CityLocation(BaseModel): content='What is the largest city in Mexico?', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index 42033cb3be..8e498188ef 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -188,6 +188,7 @@ async def test_instrumented_model(capfire: CaptureLogfire): 'output_mode': 'text', 'output_object': None, 'output_tools': [], + 'prompted_output_template': None, 'allow_text_output': True, 'allow_image_output': False, }, @@ -427,6 +428,7 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire): 'output_mode': 'text', 'output_object': None, 'output_tools': [], + 'prompted_output_template': None, 'allow_text_output': True, 'allow_image_output': False, }, @@ -526,6 +528,7 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire): 'output_mode': 'text', 'output_object': None, 'output_tools': [], + 'prompted_output_template': None, 'allow_text_output': True, 'allow_image_output': False, }, @@ -645,6 +648,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire, instr 'output_mode': 'text', 'output_object': None, 'output_tools': [], + 'prompted_output_template': None, 'allow_text_output': True, 'allow_image_output': False, }, @@ -778,6 +782,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire, instr 'output_mode': 'text', 'output_object': None, 'output_tools': [], + 'prompted_output_template': None, 'allow_text_output': True, 'allow_image_output': False, }, @@ -1492,6 +1497,7 @@ async def test_response_cost_error(capfire: CaptureLogfire, monkeypatch: pytest. 
'output_mode': 'text', 'output_object': None, 'output_tools': [], + 'prompted_output_template': None, 'allow_text_output': True, 'allow_image_output': False, }, diff --git a/tests/models/test_model_request_parameters.py b/tests/models/test_model_request_parameters.py index c7e87d44c2..1c8d0780e5 100644 --- a/tests/models/test_model_request_parameters.py +++ b/tests/models/test_model_request_parameters.py @@ -32,6 +32,7 @@ def test_model_request_parameters_are_serializable(): 'output_mode': 'text', 'output_object': None, 'output_tools': [], + 'prompted_output_template': None, 'allow_text_output': True, 'allow_image_output': False, } @@ -125,6 +126,7 @@ def test_model_request_parameters_are_serializable(): 'metadata': None, } ], + 'prompted_output_template': None, 'allow_text_output': True, 'allow_image_output': False, } diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 2307b3fd1b..f4d0496966 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -2695,14 +2695,7 @@ async def get_user_country() -> str: content='What is the largest city in the user country?', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ @@ -2733,14 +2726,7 @@ async def get_user_country() -> str: tool_call_id='call_s7oT9jaLAsEqTgvxZTmFh0wB', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], @@ -2793,14 +2779,7 @@ async def get_user_country() -> str: content='What is the largest city in the user country?', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ @@ -2831,14 +2810,7 @@ async def get_user_country() -> str: tool_call_id='call_wJD14IyJ4KKVtjCrGyNCHO09', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": 
"string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index a9a7624b09..8215515276 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -42,6 +42,7 @@ BuiltinToolCallEvent, # pyright: ignore[reportDeprecated] BuiltinToolResultEvent, # pyright: ignore[reportDeprecated] ) +from pydantic_ai.models import ModelRequestParameters from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput from pydantic_ai.profiles.openai import openai_model_profile from pydantic_ai.tools import ToolDefinition @@ -1634,14 +1635,7 @@ async def get_user_country() -> str: content='What is the largest city in the user country?', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ @@ -1668,14 +1662,7 @@ async def get_user_country() -> str: tool_call_id='call_FrlL4M0CbAy8Dhv4VqF1Shom', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ @@ -1725,14 +1712,7 @@ async def get_user_country() -> str: content='What is the largest city in the user country?', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ @@ -1759,14 +1739,7 @@ async def get_user_country() -> str: tool_call_id='call_my4OyoVXRT0m7bLWmsxcaCQI', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this 
schema: - -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ @@ -2375,7 +2348,11 @@ async def test_openai_responses_thinking_without_summary(allow_model_requests: N ] ) - _, openai_messages = await model._map_messages(result.all_messages(), model_settings=model.settings or {}) # type: ignore[reportPrivateUsage] + _, openai_messages = await model._map_messages( # type: ignore[reportPrivateUsage] + result.all_messages(), + model_settings=cast(OpenAIResponsesModelSettings, model.settings or {}), + model_request_parameters=ModelRequestParameters(), + ) assert openai_messages == snapshot( [ {'role': 'user', 'content': 'What is 2+2?'}, @@ -2445,7 +2422,11 @@ async def test_openai_responses_thinking_with_multiple_summaries(allow_model_req ] ) - _, openai_messages = await model._map_messages(result.all_messages(), model_settings=model.settings or {}) # type: ignore[reportPrivateUsage] + _, openai_messages = await model._map_messages( # type: ignore[reportPrivateUsage] + result.all_messages(), + model_settings=cast(OpenAIResponsesModelSettings, model.settings or {}), + model_request_parameters=ModelRequestParameters(), + ) assert openai_messages == snapshot( [ {'role': 'user', 'content': 'What is 2+2?'}, @@ -3573,7 +3554,11 @@ def get_meaning_of_life() -> int: ] ) - _, openai_messages = await model._map_messages(messages, model_settings=model.settings or {}) # type: ignore[reportPrivateUsage] + _, openai_messages = await model._map_messages( # type: ignore[reportPrivateUsage] + messages, + model_settings=cast(OpenAIResponsesModelSettings, model.settings or {}), + model_request_parameters=ModelRequestParameters(), + ) assert openai_messages == snapshot( [ {'role': 'user', 'content': 'What is the meaning of life?'}, @@ -5922,7 +5907,7 @@ class Animal(BaseModel): ModelRequest( parts=[ RetryPromptPart( - content='Please include your response in a tool call.', + content='Please return text or include your response in a tool call.', tool_call_id=IsStr(), timestamp=IsDatetime(), ) @@ -6054,14 +6039,7 @@ class Animal(BaseModel): content='Generate an image of an axolotl.', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"properties": {"species": {"type": "string"}, "name": {"type": "string"}}, "required": ["species", "name"], "title": "Animal", "type": "object"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ @@ -7292,7 +7270,11 @@ def get_meaning_of_life() -> int: result = await agent.run('What is the meaning of life?') messages = result.all_messages() - _, openai_messages = await model._map_messages(messages, model_settings=model.settings or {}) # type: ignore[reportPrivateUsage] + _, 
openai_messages = await model._map_messages( # type: ignore[reportPrivateUsage] + messages, + model_settings=cast(OpenAIResponsesModelSettings, model.settings or {}), + model_request_parameters=ModelRequestParameters(), + ) assert openai_messages == snapshot( [ {'role': 'user', 'content': 'What is the meaning of life?'}, diff --git a/tests/test_agent.py b/tests/test_agent.py index 8834a427aa..dc7ef42a53 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -55,13 +55,12 @@ OutputSpec, PromptedOutput, TextOutput, - ToolOutputSchema, ) from pydantic_ai.agent import AgentRunResult, WrapperAgent from pydantic_ai.builtin_tools import CodeExecutionTool, MCPServerTool, WebSearchTool from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.output import StructuredDict, ToolOutput +from pydantic_ai.output import OutputObjectDefinition, StructuredDict, ToolOutput from pydantic_ai.result import RunUsage from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import DeferredToolRequests, DeferredToolResults, ToolDefinition, ToolDenied @@ -531,12 +530,12 @@ def test_response_tuple(): m = TestModel() agent = Agent(m, output_type=tuple[str, str]) - assert isinstance(agent._output_schema, ToolOutputSchema) # pyright: ignore[reportPrivateUsage] result = agent.run_sync('Hello') assert result.output == snapshot(('a', 'a')) assert m.last_model_request_parameters is not None + assert m.last_model_request_parameters.output_mode == 'tool' assert m.last_model_request_parameters.function_tools == snapshot([]) assert m.last_model_request_parameters.allow_text_output is False @@ -1637,30 +1636,76 @@ class Node(BaseModel): def test_default_structured_output_mode(): - def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: - return ModelResponse(parts=[TextPart(content='hello')]) # pragma: no cover + class Foo(BaseModel): + bar: str - tool_model = FunctionModel(hello, profile=ModelProfile(default_structured_output_mode='tool')) - native_model = FunctionModel( - hello, + tool_model = TestModel(profile=ModelProfile(default_structured_output_mode='tool')) + native_model = TestModel( profile=ModelProfile(supports_json_schema_output=True, default_structured_output_mode='native'), + custom_output_text=Foo(bar='baz').model_dump_json(), ) - prompted_model = FunctionModel( - hello, + prompted_model = TestModel( profile=ModelProfile(default_structured_output_mode='prompted'), + custom_output_text=Foo(bar='baz').model_dump_json(), ) - class Foo(BaseModel): - bar: str - tool_agent = Agent(tool_model, output_type=Foo) - assert tool_agent._output_schema.mode == 'tool' # type: ignore[reportPrivateUsage] + tool_agent.run_sync('Hello') + assert tool_model.last_model_request_parameters is not None + assert tool_model.last_model_request_parameters.output_mode == 'tool' + assert tool_model.last_model_request_parameters.allow_text_output is False + assert tool_model.last_model_request_parameters.output_object is None + assert tool_model.last_model_request_parameters.output_tools == snapshot( + [ + ToolDefinition( + name='final_result', + parameters_json_schema={ + 'properties': {'bar': {'type': 'string'}}, + 'required': ['bar'], + 'title': 'Foo', + 'type': 'object', + }, + description='The final response which ends this conversation', + kind='output', + ) + ] + ) native_agent = Agent(native_model, output_type=Foo) - assert native_agent._output_schema.mode == 'native' # type: ignore[reportPrivateUsage] + 
native_agent.run_sync('Hello') + assert native_model.last_model_request_parameters is not None + assert native_model.last_model_request_parameters.output_mode == 'native' + assert native_model.last_model_request_parameters.allow_text_output is True + assert len(native_model.last_model_request_parameters.output_tools) == 0 + assert native_model.last_model_request_parameters.output_object == snapshot( + OutputObjectDefinition( + json_schema={ + 'properties': {'bar': {'type': 'string'}}, + 'required': ['bar'], + 'title': 'Foo', + 'type': 'object', + }, + name='Foo', + ) + ) prompted_agent = Agent(prompted_model, output_type=Foo) - assert prompted_agent._output_schema.mode == 'prompted' # type: ignore[reportPrivateUsage] + prompted_agent.run_sync('Hello') + assert prompted_model.last_model_request_parameters is not None + assert prompted_model.last_model_request_parameters.output_mode == 'prompted' + assert prompted_model.last_model_request_parameters.allow_text_output is True + assert len(prompted_model.last_model_request_parameters.output_tools) == 0 + assert prompted_model.last_model_request_parameters.output_object == snapshot( + OutputObjectDefinition( + json_schema={ + 'properties': {'bar': {'type': 'string'}}, + 'required': ['bar'], + 'title': 'Foo', + 'type': 'object', + }, + name='Foo', + ) + ) def test_prompted_output(): @@ -1691,14 +1736,7 @@ class CityLocation(BaseModel): content='What is the capital of Mexico?', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "City & Country", "type": "object", "description": "Description from PromptedOutput. Description from docstring."} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], @@ -1732,12 +1770,7 @@ class Foo(BaseModel): content='What is the capital of Mexico?', timestamp=IsDatetime(), ) - ], - instructions="""\ -Gimme some JSON: - -{"properties": {"bar": {"type": "string"}}, "required": ["bar"], "title": "Foo", "type": "object"}\ -""", + ] ), ModelResponse( parts=[TextPart(content='{"bar":"baz"}')], @@ -1800,14 +1833,7 @@ def return_foo_bar(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: content='What is foo?', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "FooBar"}, "data": {"properties": {"foo": {"$ref": "#/$defs/Foo"}, "bar": {"$ref": "#/$defs/Bar"}}, "required": ["foo", "bar"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "FooBar", "description": "FooBar description"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "FooBaz"}, "data": {"properties": {"foo": {"$ref": "#/$defs/Foo"}, "baz": {"$ref": "#/$defs/Baz"}}, "required": ["foo", "baz"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "FooBaz", "description": "FooBaz description"}]}}, "required": ["result"], "additionalProperties": false, "$defs": {"Bar": {"description": "Bar description", "properties": {"bar": {"type": "string"}}, "required": ["bar"], "title": "Bar", "type": "object"}, "Foo": {"description": "Foo description", "properties": {"foo": {"type": 
"string"}}, "required": ["foo"], "title": "Foo", "type": "object"}, "Baz": {"description": "Baz description", "properties": {"baz": {"type": "string"}}, "required": ["baz"], "title": "Baz", "type": "object"}}, "title": "FooBar or FooBaz", "description": "FooBaz description"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[ @@ -1894,6 +1920,7 @@ class CityLocation(BaseModel): agent = Agent(output_type=NativeOutput(CityLocation, strict=True)) output_schema = agent._output_schema # pyright: ignore[reportPrivateUsage] assert isinstance(output_schema, NativeOutputSchema) + assert output_schema.object_def is not None assert output_schema.object_def.strict @@ -1928,14 +1955,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: content='New York City', timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"additionalProperties": false, "properties": {"city": {"type": "string"}}, "required": ["city"], "type": "object", "title": "get_weather"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[TextPart(content='{"city": "New York City"}')], @@ -1950,14 +1970,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: tool_call_id=IsStr(), timestamp=IsDatetime(), ) - ], - instructions="""\ -Always respond with a JSON object that's compatible with this schema: - -{"additionalProperties": false, "properties": {"city": {"type": "string"}}, "required": ["city"], "type": "object", "title": "get_weather"} - -Don't include any text or Markdown fencing before or after.\ -""", + ] ), ModelResponse( parts=[TextPart(content='{"city": "Mexico City"}')], @@ -3779,7 +3792,6 @@ def get_image() -> BinaryContent: data=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82', media_type='image/png', _identifier='image_id_1', - identifier='image_id_1', ), ], timestamp=IsNow(tz=timezone.utc), diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 6cc57b831a..dadb930dd0 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -553,6 +553,7 @@ async def my_ret(x: int) -> str: 'output_mode': 'text', 'output_tools': [], 'output_object': None, + 'prompted_output_template': None, 'allow_text_output': True, 'allow_image_output': False, } @@ -995,6 +996,7 @@ class MyOutput: 'metadata': None, } ], + 'prompted_output_template': None, 'allow_text_output': False, 'allow_image_output': False, } @@ -1100,6 +1102,7 @@ async def test_feedback(capfire: CaptureLogfire) -> None: 'output_mode': 'text', 'output_object': None, 'output_tools': [], + 'prompted_output_template': None, 'allow_text_output': True, 'allow_image_output': False, },