|
41 | 41 | ) |
42 | 42 | from pydantic_ai.models.function import AgentInfo, FunctionModel |
43 | 43 | from pydantic_ai.models.test import TestModel |
44 | | -from pydantic_ai.output import ToolOutput |
| 44 | +from pydantic_ai.output import StructuredDict, ToolOutput |
45 | 45 | from pydantic_ai.profiles import ModelProfile |
46 | 46 | from pydantic_ai.result import Usage |
47 | 47 | from pydantic_ai.tools import ToolDefinition |
@@ -1266,6 +1266,77 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: |
1266 | 1266 | ) |
1267 | 1267 |
|
1268 | 1268 |
|
| 1269 | +def test_output_type_structured_dict(): |
| 1270 | + PersonDict = StructuredDict( |
| 1271 | + { |
| 1272 | + 'type': 'object', |
| 1273 | + 'properties': { |
| 1274 | + 'name': {'type': 'string'}, |
| 1275 | + 'age': {'type': 'integer'}, |
| 1276 | + }, |
| 1277 | + 'required': ['name', 'age'], |
| 1278 | + }, |
| 1279 | + name='Person', |
| 1280 | + description='A person', |
| 1281 | + ) |
| 1282 | + AnimalDict = StructuredDict( |
| 1283 | + { |
| 1284 | + 'type': 'object', |
| 1285 | + 'properties': { |
| 1286 | + 'name': {'type': 'string'}, |
| 1287 | + 'species': {'type': 'string'}, |
| 1288 | + }, |
| 1289 | + 'required': ['name', 'species'], |
| 1290 | + }, |
| 1291 | + name='Animal', |
| 1292 | + description='An animal', |
| 1293 | + ) |
| 1294 | + |
| 1295 | + output_tools = None |
| 1296 | + |
| 1297 | + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: |
| 1298 | + assert info.output_tools is not None |
| 1299 | + |
| 1300 | + nonlocal output_tools |
| 1301 | + output_tools = info.output_tools |
| 1302 | + |
| 1303 | + args_json = '{"name": "John Doe", "age": 30}' |
| 1304 | + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) |
| 1305 | + |
| 1306 | + agent = Agent( |
| 1307 | + FunctionModel(call_tool), |
| 1308 | + output_type=[PersonDict, AnimalDict], |
| 1309 | + ) |
| 1310 | + |
| 1311 | + result = agent.run_sync('Generate a person') |
| 1312 | + |
| 1313 | + assert result.output == snapshot({'name': 'John Doe', 'age': 30}) |
| 1314 | + assert output_tools == snapshot( |
| 1315 | + [ |
| 1316 | + ToolDefinition( |
| 1317 | + name='final_result_Person', |
| 1318 | + parameters_json_schema={ |
| 1319 | + 'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}}, |
| 1320 | + 'required': ['name', 'age'], |
| 1321 | + 'title': 'Person', |
| 1322 | + 'type': 'object', |
| 1323 | + }, |
| 1324 | + description='A person', |
| 1325 | + ), |
| 1326 | + ToolDefinition( |
| 1327 | + name='final_result_Animal', |
| 1328 | + parameters_json_schema={ |
| 1329 | + 'properties': {'name': {'type': 'string'}, 'species': {'type': 'string'}}, |
| 1330 | + 'required': ['name', 'species'], |
| 1331 | + 'title': 'Animal', |
| 1332 | + 'type': 'object', |
| 1333 | + }, |
| 1334 | + description='An animal', |
| 1335 | + ), |
| 1336 | + ] |
| 1337 | + ) |
| 1338 | + |
| 1339 | + |
1269 | 1340 | def test_default_structured_output_mode(): |
1270 | 1341 | def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: |
1271 | 1342 | return ModelResponse(parts=[TextPart(content='hello')]) # pragma: no cover |
|
0 commit comments