40 changes: 21 additions & 19 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py

@@ -659,11 +659,11 @@ async def process_function_tools(  # noqa: C901
         for call in calls_to_run:
             yield _messages.FunctionToolCallEvent(call)
 
-        user_parts: list[_messages.UserPromptPart] = []
+        user_parts_by_index: dict[int, list[_messages.UserPromptPart]] = defaultdict(list)
 
         if calls_to_run:
             # Run all tool tasks in parallel
-            parts_by_index: dict[int, list[_messages.ModelRequestPart]] = {}
+            tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
             with ctx.deps.tracer.start_as_current_span(
                 'running tools',
                 attributes={
@@ -681,15 +681,16 @@ async def process_function_tools(  # noqa: C901
                     done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
                     for task in done:
                         index = tasks.index(task)
-                        tool_result_part, extra_parts = task.result()
-                        yield _messages.FunctionToolResultEvent(tool_result_part)
+                        tool_part, tool_user_parts = task.result()
+                        yield _messages.FunctionToolResultEvent(tool_part)
 
-                        parts_by_index[index] = [tool_result_part, *extra_parts]
+                        tool_parts_by_index[index] = tool_part
+                        user_parts_by_index[index] = tool_user_parts
 
             # We append the results at the end, rather than as they are received, to retain a consistent ordering
             # This is mostly just to simplify testing
-            for k in sorted(parts_by_index):
-                output_parts.extend(parts_by_index[k])
+            for k in sorted(tool_parts_by_index):
+                output_parts.append(tool_parts_by_index[k])
 
     # Finally, we handle deferred tool calls
     for call in tool_calls_by_kind['deferred']:
@@ -704,7 +705,8 @@ async def process_function_tools(  # noqa: C901
         else:
             yield _messages.FunctionToolCallEvent(call)
 
-    output_parts.extend(user_parts)
+    for k in sorted(user_parts_by_index):
+        output_parts.extend(user_parts_by_index[k])
 
     if final_result:
         output_final_result.append(final_result)
@@ -713,18 +715,18 @@ async def _call_function_tool(
 async def _call_function_tool(
     tool_manager: ToolManager[DepsT],
     tool_call: _messages.ToolCallPart,
-) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.ModelRequestPart]]:
+) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.UserPromptPart]]:
     try:
         tool_result = await tool_manager.handle_call(tool_call)
     except ToolRetryError as e:
        return (e.tool_retry, [])
 
-    part = _messages.ToolReturnPart(
+    tool_part = _messages.ToolReturnPart(
         tool_name=tool_call.tool_name,
         content=tool_result,
         tool_call_id=tool_call.tool_call_id,
     )
-    extra_parts: list[_messages.ModelRequestPart] = []
+    user_parts: list[_messages.UserPromptPart] = []
 
     if isinstance(tool_result, _messages.ToolReturn):
         if (
@@ -740,12 +742,12 @@ async def _call_function_tool(
                 f'Please use `content` instead.'
             )
 
-        part.content = tool_result.return_value  # type: ignore
-        part.metadata = tool_result.metadata
+        tool_part.content = tool_result.return_value  # type: ignore
+        tool_part.metadata = tool_result.metadata
         if tool_result.content:
-            extra_parts.append(
+            user_parts.append(
                 _messages.UserPromptPart(
-                    content=list(tool_result.content),
+                    content=tool_result.content,
                     part_kind='user-prompt',
                 )
             )
@@ -763,7 +765,7 @@ def process_content(content: Any) -> Any:
             else:
                 identifier = multi_modal_content_identifier(content.url)
 
-            extra_parts.append(
+            user_parts.append(
                 _messages.UserPromptPart(
                     content=[f'This is file {identifier}:', content],
                     part_kind='user-prompt',
@@ -775,11 +777,11 @@ def process_content(content: Any) -> Any:
 
     if isinstance(tool_result, list):
         contents = cast(list[Any], tool_result)
-        part.content = [process_content(content) for content in contents]
+        tool_part.content = [process_content(content) for content in contents]
     else:
-        part.content = process_content(tool_result)
+        tool_part.content = process_content(tool_result)
 
-    return (part, extra_parts)
+    return (tool_part, user_parts)
 
 
 @dataclasses.dataclass
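The heart of this change is the two-level ordering: tool return parts are flushed first, sorted by call index, and the user-prompt parts produced from `ToolReturn.content` follow afterwards, also in call order. Below is a minimal standalone sketch of that buffer-by-index pattern, kept deliberately separate from pydantic-ai's API; all names in it are illustrative only.

import asyncio


async def run_in_call_order(coros):
    # Wrap each coroutine in a task; the list index records the call order.
    tasks = [asyncio.ensure_future(c) for c in coros]
    results_by_index: dict[int, object] = {}
    pending = set(tasks)
    while pending:
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            # Buffer each result under the index of its originating call.
            results_by_index[tasks.index(task)] = task.result()
    # Flush in call order, not completion order.
    return [results_by_index[i] for i in sorted(results_by_index)]


async def main():
    async def price(fruit: str, delay: float) -> str:
        await asyncio.sleep(delay)
        return f'{fruit}: 10.0'

    # 'banana' finishes first, but the output preserves call order.
    print(await run_in_call_order([price('apple', 0.2), price('banana', 0.1)]))


asyncio.run(main())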
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/messages.py

@@ -412,8 +412,8 @@ class ToolReturn:
     return_value: Any
     """The return value to be used in the tool response."""
 
-    content: Sequence[UserContent] | None = None
-    """The content sequence to be sent to the model as a UserPromptPart."""
+    content: str | Sequence[UserContent] | None = None
+    """The content to be sent to the model as a UserPromptPart."""
 
     metadata: Any = None
    """Additional data that can be accessed programmatically by the application but is not sent to the LLM."""
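Widening `content` from `Sequence[UserContent]` to `str | Sequence[UserContent]` lets a tool pass a plain string as the model-facing message, which is exactly what the `get_price` tool in the test below does. A minimal usage sketch:

from pydantic_ai.messages import ToolReturn


def get_price(fruit: str) -> ToolReturn:
    return ToolReturn(
        return_value=10.0,                        # becomes the ToolReturnPart content
        content=f'The price of {fruit} is 10.0',  # now a plain str; sent as a UserPromptPart
        metadata={'foo': 'bar'},                  # app-side only, never sent to the LLM
    )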
102 changes: 100 additions & 2 deletions tests/test_tools.py

@@ -13,16 +13,26 @@
 
 from pydantic_ai import Agent, RunContext, Tool, UserError
 from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior
-from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, ToolCallPart, ToolReturnPart
+from pydantic_ai.messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    TextPart,
+    ToolCallPart,
+    ToolReturn,
+    ToolReturnPart,
+    UserPromptPart,
+)
 from pydantic_ai.models.function import AgentInfo, FunctionModel
 from pydantic_ai.models.test import TestModel
 from pydantic_ai.output import DeferredToolCalls, ToolOutput
 from pydantic_ai.tools import ToolDefinition
 from pydantic_ai.toolsets.deferred import DeferredToolset
 from pydantic_ai.toolsets.function import FunctionToolset
 from pydantic_ai.toolsets.prefixed import PrefixedToolset
+from pydantic_ai.usage import Usage
 
-from .conftest import IsStr
+from .conftest import IsDatetime, IsStr
 
 
 def test_tool_no_ctx():
@@ -1321,3 +1331,91 @@ def test_output_type_deferred_tool_calls_by_itself():
 def test_output_type_empty():
     with pytest.raises(UserError, match='At least one output type must be provided.'):
         Agent(TestModel(), output_type=[])
+
+
+def test_parallel_tool_return():
+    def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        if len(messages) == 1:
+            return ModelResponse(
+                parts=[ToolCallPart('get_price', {'fruit': 'apple'}), ToolCallPart('get_price', {'fruit': 'banana'})]
+            )
+        else:
+            return ModelResponse(
+                parts=[
+                    TextPart('Done!'),
+                ]
+            )
+
+    agent = Agent(FunctionModel(llm))
+
+    @agent.tool_plain
+    def get_price(fruit: str) -> ToolReturn:
+        return ToolReturn(
+            return_value=10.0,
+            content=f'The price of {fruit} is 10.0',
+            metadata={'foo': 'bar'},
+        )
+
+    result = agent.run_sync('What do an apple and a banana cost?')
+
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='What do an apple and a banana cost?',
+                        timestamp=IsDatetime(),
+                    )
+                ]
+            ),
+            ModelResponse(
+                parts=[
+                    ToolCallPart(
+                        tool_name='get_price',
+                        args={'fruit': 'apple'},
+                        tool_call_id=IsStr(),
+                    ),
+                    ToolCallPart(
+                        tool_name='get_price',
+                        args={'fruit': 'banana'},
+                        tool_call_id=IsStr(),
+                    ),
+                ],
+                usage=Usage(requests=1, request_tokens=58, response_tokens=10, total_tokens=68),
+                model_name='function:llm:',
+                timestamp=IsDatetime(),
+            ),
+            ModelRequest(
+                parts=[
+                    ToolReturnPart(
+                        tool_name='get_price',
+                        content=10.0,
+                        tool_call_id=IsStr(),
+                        metadata={'foo': 'bar'},
+                        timestamp=IsDatetime(),
+                    ),
+                    ToolReturnPart(
+                        tool_name='get_price',
+                        content=10.0,
+                        tool_call_id=IsStr(),
+                        metadata={'foo': 'bar'},
+                        timestamp=IsDatetime(),
+                    ),
+                    UserPromptPart(
+                        content='The price of apple is 10.0',
+                        timestamp=IsDatetime(),
+                    ),
+                    UserPromptPart(
+                        content='The price of banana is 10.0',
+                        timestamp=IsDatetime(),
+                    ),
+                ]
+            ),
+            ModelResponse(
+                parts=[TextPart(content='Done!')],
+                usage=Usage(requests=1, request_tokens=76, response_tokens=11, total_tokens=87),
+                model_name='function:llm:',
+                timestamp=IsDatetime(),
+            ),
+        ]
+    )