diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 312a8a2fca..12e6e07fe8 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -659,11 +659,11 @@ async def process_function_tools( # noqa: C901 for call in calls_to_run: yield _messages.FunctionToolCallEvent(call) - user_parts: list[_messages.UserPromptPart] = [] + user_parts_by_index: dict[int, list[_messages.UserPromptPart]] = defaultdict(list) if calls_to_run: # Run all tool tasks in parallel - parts_by_index: dict[int, list[_messages.ModelRequestPart]] = {} + tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {} with ctx.deps.tracer.start_as_current_span( 'running tools', attributes={ @@ -681,15 +681,16 @@ async def process_function_tools( # noqa: C901 done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) for task in done: index = tasks.index(task) - tool_result_part, extra_parts = task.result() - yield _messages.FunctionToolResultEvent(tool_result_part) + tool_part, tool_user_parts = task.result() + yield _messages.FunctionToolResultEvent(tool_part) - parts_by_index[index] = [tool_result_part, *extra_parts] + tool_parts_by_index[index] = tool_part + user_parts_by_index[index] = tool_user_parts # We append the results at the end, rather than as they are received, to retain a consistent ordering # This is mostly just to simplify testing - for k in sorted(parts_by_index): - output_parts.extend(parts_by_index[k]) + for k in sorted(tool_parts_by_index): + output_parts.append(tool_parts_by_index[k]) # Finally, we handle deferred tool calls for call in tool_calls_by_kind['deferred']: @@ -704,7 +705,8 @@ async def process_function_tools( # noqa: C901 else: yield _messages.FunctionToolCallEvent(call) - output_parts.extend(user_parts) + for k in sorted(user_parts_by_index): + output_parts.extend(user_parts_by_index[k]) if final_result: output_final_result.append(final_result) @@ -713,18 +715,18 @@ async def process_function_tools( # noqa: C901 async def _call_function_tool( tool_manager: ToolManager[DepsT], tool_call: _messages.ToolCallPart, -) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.ModelRequestPart]]: +) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.UserPromptPart]]: try: tool_result = await tool_manager.handle_call(tool_call) except ToolRetryError as e: return (e.tool_retry, []) - part = _messages.ToolReturnPart( + tool_part = _messages.ToolReturnPart( tool_name=tool_call.tool_name, content=tool_result, tool_call_id=tool_call.tool_call_id, ) - extra_parts: list[_messages.ModelRequestPart] = [] + user_parts: list[_messages.UserPromptPart] = [] if isinstance(tool_result, _messages.ToolReturn): if ( @@ -740,12 +742,12 @@ async def _call_function_tool( f'Please use `content` instead.' ) - part.content = tool_result.return_value # type: ignore - part.metadata = tool_result.metadata + tool_part.content = tool_result.return_value # type: ignore + tool_part.metadata = tool_result.metadata if tool_result.content: - extra_parts.append( + user_parts.append( _messages.UserPromptPart( - content=list(tool_result.content), + content=tool_result.content, part_kind='user-prompt', ) ) @@ -763,7 +765,7 @@ def process_content(content: Any) -> Any: else: identifier = multi_modal_content_identifier(content.url) - extra_parts.append( + user_parts.append( _messages.UserPromptPart( content=[f'This is file {identifier}:', content], part_kind='user-prompt', @@ -775,11 +777,11 @@ def process_content(content: Any) -> Any: if isinstance(tool_result, list): contents = cast(list[Any], tool_result) - part.content = [process_content(content) for content in contents] + tool_part.content = [process_content(content) for content in contents] else: - part.content = process_content(tool_result) + tool_part.content = process_content(tool_result) - return (part, extra_parts) + return (tool_part, user_parts) @dataclasses.dataclass diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 51e63eea5b..b5d7be2857 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -412,8 +412,8 @@ class ToolReturn: return_value: Any """The return value to be used in the tool response.""" - content: Sequence[UserContent] | None = None - """The content sequence to be sent to the model as a UserPromptPart.""" + content: str | Sequence[UserContent] | None = None + """The content to be sent to the model as a UserPromptPart.""" metadata: Any = None """Additional data that can be accessed programmatically by the application but is not sent to the LLM.""" diff --git a/tests/test_tools.py b/tests/test_tools.py index e6a21a8915..7f4a45804b 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -13,7 +13,16 @@ from pydantic_ai import Agent, RunContext, Tool, UserError from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior -from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, ToolCallPart, ToolReturnPart +from pydantic_ai.messages import ( + ModelMessage, + ModelRequest, + ModelResponse, + TextPart, + ToolCallPart, + ToolReturn, + ToolReturnPart, + UserPromptPart, +) from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel from pydantic_ai.output import DeferredToolCalls, ToolOutput @@ -21,8 +30,9 @@ from pydantic_ai.toolsets.deferred import DeferredToolset from pydantic_ai.toolsets.function import FunctionToolset from pydantic_ai.toolsets.prefixed import PrefixedToolset +from pydantic_ai.usage import Usage -from .conftest import IsStr +from .conftest import IsDatetime, IsStr def test_tool_no_ctx(): @@ -1321,3 +1331,91 @@ def test_output_type_deferred_tool_calls_by_itself(): def test_output_type_empty(): with pytest.raises(UserError, match='At least one output type must be provided.'): Agent(TestModel(), output_type=[]) + + +def test_parallel_tool_return(): + def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + return ModelResponse( + parts=[ToolCallPart('get_price', {'fruit': 'apple'}), ToolCallPart('get_price', {'fruit': 'banana'})] + ) + else: + return ModelResponse( + parts=[ + TextPart('Done!'), + ] + ) + + agent = Agent(FunctionModel(llm)) + + @agent.tool_plain + def get_price(fruit: str) -> ToolReturn: + return ToolReturn( + return_value=10.0, + content=f'The price of {fruit} is 10.0', + metadata={'foo': 'bar'}, + ) + + result = agent.run_sync('What do an apple and a banana cost?') + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What do an apple and a banana cost?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_price', + args={'fruit': 'apple'}, + tool_call_id=IsStr(), + ), + ToolCallPart( + tool_name='get_price', + args={'fruit': 'banana'}, + tool_call_id=IsStr(), + ), + ], + usage=Usage(requests=1, request_tokens=58, response_tokens=10, total_tokens=68), + model_name='function:llm:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_price', + content=10.0, + tool_call_id=IsStr(), + metadata={'foo': 'bar'}, + timestamp=IsDatetime(), + ), + ToolReturnPart( + tool_name='get_price', + content=10.0, + tool_call_id=IsStr(), + metadata={'foo': 'bar'}, + timestamp=IsDatetime(), + ), + UserPromptPart( + content='The price of apple is 10.0', + timestamp=IsDatetime(), + ), + UserPromptPart( + content='The price of banana is 10.0', + timestamp=IsDatetime(), + ), + ] + ), + ModelResponse( + parts=[TextPart(content='Done!')], + usage=Usage(requests=1, request_tokens=76, response_tokens=11, total_tokens=87), + model_name='function:llm:', + timestamp=IsDatetime(), + ), + ] + )