@@ -659,11 +659,11 @@ async def process_function_tools( # noqa: C901
659659 for call in calls_to_run :
660660 yield _messages .FunctionToolCallEvent (call )
661661
662- user_parts : list [_messages .UserPromptPart ] = []
662+ user_parts_by_index : dict [ int , list [_messages .UserPromptPart ]] = defaultdict ( list )
663663
664664 if calls_to_run :
665665 # Run all tool tasks in parallel
666- parts_by_index : dict [int , list [ _messages .ModelRequestPart ] ] = {}
666+ tool_parts_by_index : dict [int , _messages .ModelRequestPart ] = {}
667667 with ctx .deps .tracer .start_as_current_span (
668668 'running tools' ,
669669 attributes = {
@@ -681,15 +681,16 @@ async def process_function_tools( # noqa: C901
681681 done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
682682 for task in done :
683683 index = tasks .index (task )
684- tool_result_part , extra_parts = task .result ()
685- yield _messages .FunctionToolResultEvent (tool_result_part )
684+ tool_part , tool_user_parts = task .result ()
685+ yield _messages .FunctionToolResultEvent (tool_part )
686686
687- parts_by_index [index ] = [tool_result_part , * extra_parts ]
687+ tool_parts_by_index [index ] = tool_part
688+ user_parts_by_index [index ] = tool_user_parts
688689
689690 # We append the results at the end, rather than as they are received, to retain a consistent ordering
690691 # This is mostly just to simplify testing
691- for k in sorted (parts_by_index ):
692- output_parts .extend ( parts_by_index [k ])
692+ for k in sorted (tool_parts_by_index ):
693+ output_parts .append ( tool_parts_by_index [k ])
693694
694695 # Finally, we handle deferred tool calls
695696 for call in tool_calls_by_kind ['deferred' ]:
@@ -704,7 +705,8 @@ async def process_function_tools( # noqa: C901
704705 else :
705706 yield _messages .FunctionToolCallEvent (call )
706707
707- output_parts .extend (user_parts )
708+ for k in sorted (user_parts_by_index ):
709+ output_parts .extend (user_parts_by_index [k ])
708710
709711 if final_result :
710712 output_final_result .append (final_result )
@@ -713,18 +715,18 @@ async def process_function_tools( # noqa: C901
713715async def _call_function_tool (
714716 tool_manager : ToolManager [DepsT ],
715717 tool_call : _messages .ToolCallPart ,
716- ) -> tuple [_messages .ToolReturnPart | _messages .RetryPromptPart , list [_messages .ModelRequestPart ]]:
718+ ) -> tuple [_messages .ToolReturnPart | _messages .RetryPromptPart , list [_messages .UserPromptPart ]]:
717719 try :
718720 tool_result = await tool_manager .handle_call (tool_call )
719721 except ToolRetryError as e :
720722 return (e .tool_retry , [])
721723
722- part = _messages .ToolReturnPart (
724+ tool_part = _messages .ToolReturnPart (
723725 tool_name = tool_call .tool_name ,
724726 content = tool_result ,
725727 tool_call_id = tool_call .tool_call_id ,
726728 )
727- extra_parts : list [_messages .ModelRequestPart ] = []
729+ user_parts : list [_messages .UserPromptPart ] = []
728730
729731 if isinstance (tool_result , _messages .ToolReturn ):
730732 if (
@@ -740,12 +742,12 @@ async def _call_function_tool(
740742 f'Please use `content` instead.'
741743 )
742744
743- part .content = tool_result .return_value # type: ignore
744- part .metadata = tool_result .metadata
745+ tool_part .content = tool_result .return_value # type: ignore
746+ tool_part .metadata = tool_result .metadata
745747 if tool_result .content :
746- extra_parts .append (
748+ user_parts .append (
747749 _messages .UserPromptPart (
748- content = list ( tool_result .content ) ,
750+ content = tool_result .content ,
749751 part_kind = 'user-prompt' ,
750752 )
751753 )
@@ -763,7 +765,7 @@ def process_content(content: Any) -> Any:
763765 else :
764766 identifier = multi_modal_content_identifier (content .url )
765767
766- extra_parts .append (
768+ user_parts .append (
767769 _messages .UserPromptPart (
768770 content = [f'This is file { identifier } :' , content ],
769771 part_kind = 'user-prompt' ,
@@ -775,11 +777,11 @@ def process_content(content: Any) -> Any:
775777
776778 if isinstance (tool_result , list ):
777779 contents = cast (list [Any ], tool_result )
778- part .content = [process_content (content ) for content in contents ]
780+ tool_part .content = [process_content (content ) for content in contents ]
779781 else :
780- part .content = process_content (tool_result )
782+ tool_part .content = process_content (tool_result )
781783
782- return (part , extra_parts )
784+ return (tool_part , user_parts )
783785
784786
785787@dataclasses .dataclass
0 commit comments