Skip to content

Commit aa1cc5a

Browse files
authored
Fix ag_ui message toolcall order to support Anthropic requirements
Fixes #2327
1 parent 6207ac6 commit aa1cc5a

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,6 @@ async def _handle_tool_result_event(
477477
if isinstance(item, BaseEvent): # pragma: no branch
478478
yield item
479479

480-
481480
def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
482481
"""Convert a AG-UI history to a Pydantic AI one."""
483482
result: list[ModelMessage] = []
@@ -486,7 +485,23 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
486485
if isinstance(msg, UserMessage):
487486
result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)]))
488487
elif isinstance(msg, AssistantMessage):
489-
if msg.tool_calls:
488+
if msg.tool_calls and msg.content:
489+
# When both content and tool_calls exist, combine them in the correct order:
490+
# text content first, then tool calls (preserving original assistant behavior)
491+
for tool_call in msg.tool_calls:
492+
tool_calls[tool_call.id] = tool_call.function.name
493+
494+
parts = [TextPart(content=msg.content)]
495+
parts.extend([
496+
ToolCallPart(
497+
tool_name=tool_call.function.name,
498+
tool_call_id=tool_call.id,
499+
args=tool_call.function.arguments,
500+
)
501+
for tool_call in msg.tool_calls
502+
])
503+
result.append(ModelResponse(parts=parts))
504+
elif msg.tool_calls:
490505
for tool_call in msg.tool_calls:
491506
tool_calls[tool_call.id] = tool_call.function.name
492507

@@ -502,8 +517,7 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
502517
]
503518
)
504519
)
505-
506-
if msg.content:
520+
elif msg.content:
507521
result.append(ModelResponse(parts=[TextPart(content=msg.content)]))
508522
elif isinstance(msg, SystemMessage):
509523
result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
@@ -528,7 +542,6 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]:
528542

529543
return result
530544

531-
532545
@runtime_checkable
533546
class StateHandler(Protocol):
534547
"""Protocol for state handlers in agent runs."""

0 commit comments

Comments
 (0)