diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index 8f43120d0d..14d3a45284 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -790,12 +790,18 @@ def get_user_agent() -> str:
 def _customize_tool_def(transformer: type[JsonSchemaTransformer], t: ToolDefinition):
     schema_transformer = transformer(t.parameters_json_schema, strict=t.strict)
     parameters_json_schema = schema_transformer.walk()
-    if t.strict is None:
-        t = replace(t, strict=schema_transformer.is_strict_compatible)
-    return replace(t, parameters_json_schema=parameters_json_schema)
+    return replace(
+        t,
+        parameters_json_schema=parameters_json_schema,
+        strict=schema_transformer.is_strict_compatible if t.strict is None else t.strict,
+    )
 
 
 def _customize_output_object(transformer: type[JsonSchemaTransformer], o: OutputObjectDefinition):
-    schema_transformer = transformer(o.json_schema, strict=True)
-    son_schema = schema_transformer.walk()
-    return replace(o, json_schema=son_schema)
+    schema_transformer = transformer(o.json_schema, strict=o.strict)
+    json_schema = schema_transformer.walk()
+    return replace(
+        o,
+        json_schema=json_schema,
+        strict=schema_transformer.is_strict_compatible if o.strict is None else o.strict,
+    )
diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py
index d281fbb238..f8605021d3 100644
--- a/tests/models/test_openai.py
+++ b/tests/models/test_openai.py
@@ -12,7 +12,7 @@
 import pytest
 from dirty_equals import IsListOrTuple
 from inline_snapshot import snapshot
-from pydantic import AnyUrl, BaseModel, Discriminator, Field, Tag
+from pydantic import AnyUrl, BaseModel, ConfigDict, Discriminator, Field, Tag
 from typing_extensions import NotRequired, TypedDict
 
 from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior
@@ -1777,6 +1777,44 @@ class MyModel(BaseModel):
     )
 
 
+def test_native_output_strict_mode(allow_model_requests: None):
+    class CityLocation(BaseModel):
+        city: str
+        country: str
+
+    c = completion_message(
+        ChatCompletionMessage(content='{"city": "Mexico City", "country": "Mexico"}', role='assistant'),
+    )
+    mock_client = MockOpenAI.create_mock(c)
+    model = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
+
+    # Explicit strict=True
+    agent = Agent(model, output_type=NativeOutput(CityLocation, strict=True))
+
+    agent.run_sync('What is the capital of Mexico?')
+    assert get_mock_chat_completion_kwargs(mock_client)[-1]['response_format']['json_schema']['strict'] is True
+
+    # Explicit strict=False
+    agent = Agent(model, output_type=NativeOutput(CityLocation, strict=False))
+
+    agent.run_sync('What is the capital of Mexico?')
+    assert get_mock_chat_completion_kwargs(mock_client)[-1]['response_format']['json_schema']['strict'] is False
+
+    # Strict-compatible
+    agent = Agent(model, output_type=NativeOutput(CityLocation))
+
+    agent.run_sync('What is the capital of Mexico?')
+    assert get_mock_chat_completion_kwargs(mock_client)[-1]['response_format']['json_schema']['strict'] is True
+
+    # Strict-incompatible
+    CityLocation.model_config = ConfigDict(extra='allow')
+
+    agent = Agent(model, output_type=NativeOutput(CityLocation))
+
+    agent.run_sync('What is the capital of Mexico?')
+    assert get_mock_chat_completion_kwargs(mock_client)[-1]['response_format']['json_schema']['strict'] is False
+
+
 async def test_openai_instructions(allow_model_requests: None, openai_api_key: str):
     m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key))
     agent = Agent(m, instructions='You are a helpful assistant.')