2323from . import _output , _system_prompt , exceptions , messages as _messages , models , result , usage as _usage
2424from .exceptions import ToolRetryError
2525from .output import OutputDataT , OutputSpec
26- from .settings import ModelSettings , merge_model_settings
26+ from .settings import ModelSettings
2727from .tools import RunContext , ToolDefinition , ToolKind
2828
2929if TYPE_CHECKING :
@@ -158,28 +158,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
158158
159159 async def run (
160160 self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
161- ) -> ModelRequestNode [DepsT , NodeRunEndT ]:
162- return ModelRequestNode [DepsT , NodeRunEndT ](request = await self ._get_first_message (ctx ))
163-
164- async def _get_first_message (
165- self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
166- ) -> _messages .ModelRequest :
167- run_context = build_run_context (ctx )
168- history , next_message = await self ._prepare_messages (
169- self .user_prompt , ctx .state .message_history , ctx .deps .get_instructions , run_context
170- )
171- ctx .state .message_history = history
172- run_context .messages = history
173-
174- return next_message
175-
176- async def _prepare_messages (
177- self ,
178- user_prompt : str | Sequence [_messages .UserContent ] | None ,
179- message_history : list [_messages .ModelMessage ] | None ,
180- get_instructions : Callable [[RunContext [DepsT ]], Awaitable [str | None ]],
181- run_context : RunContext [DepsT ],
182- ) -> tuple [list [_messages .ModelMessage ], _messages .ModelRequest ]:
161+ ) -> Union [ModelRequestNode [DepsT , NodeRunEndT ], CallToolsNode [DepsT , NodeRunEndT ]]: # noqa UP007
183162 try :
184163 ctx_messages = get_captured_run_messages ()
185164 except LookupError :
@@ -191,29 +170,48 @@ async def _prepare_messages(
191170 messages = ctx_messages .messages
192171 ctx_messages .used = True
193172
173+ # Add message history to the `capture_run_messages` list, which will be empty at this point
174+ messages .extend (ctx .state .message_history )
175+ # Use the `capture_run_messages` list as the message history so that new messages are added to it
176+ ctx .state .message_history = messages
177+
178+ run_context = build_run_context (ctx )
179+
194180 parts : list [_messages .ModelRequestPart ] = []
195- instructions = await get_instructions (run_context )
196- if message_history :
197- # Shallow copy messages
198- messages .extend (message_history )
181+ if messages :
199182 # Reevaluate any dynamic system prompt parts
200183 await self ._reevaluate_dynamic_prompts (messages , run_context )
201184 else :
202185 parts .extend (await self ._sys_parts (run_context ))
203186
204- if user_prompt is not None :
205- parts .append (_messages .UserPromptPart (user_prompt ))
206- elif (
207- len (parts ) == 0
208- and message_history
209- and (last_message := message_history [- 1 ])
210- and isinstance (last_message , _messages .ModelRequest )
211- ):
212- # Drop last message that came from history and reuse its parts
213- messages .pop ()
214- parts .extend (last_message .parts )
187+ if messages and (last_message := messages [- 1 ]):
188+ if isinstance (last_message , _messages .ModelRequest ) and self .user_prompt is None :
189+ # Drop last message from history and reuse its parts
190+ messages .pop ()
191+ parts .extend (last_message .parts )
192+ elif isinstance (last_message , _messages .ModelResponse ):
193+ if self .user_prompt is None :
194+ # `CallToolsNode` requires the tool manager to be prepared for the run step
195+ # This will raise errors for any tool name conflicts
196+ ctx .deps .tool_manager = await ctx .deps .tool_manager .for_run_step (run_context )
197+
198+ # Skip ModelRequestNode and go directly to CallToolsNode
199+ return CallToolsNode [DepsT , NodeRunEndT ](model_response = last_message )
200+ elif any (isinstance (part , _messages .ToolCallPart ) for part in last_message .parts ):
201+ raise exceptions .UserError (
202+ 'Cannot provide a new user prompt when the message history ends with '
203+ 'a model response containing unprocessed tool calls. Either process the '
204+ 'tool calls first (by calling `iter` with `user_prompt=None`) or append a '
205+ '`ModelRequest` with `ToolResultPart`s.'
206+ )
207+
208+ if self .user_prompt is not None :
209+ parts .append (_messages .UserPromptPart (self .user_prompt ))
210+
211+ instructions = await ctx .deps .get_instructions (run_context )
212+ next_message = _messages .ModelRequest (parts , instructions = instructions )
215213
216- return messages , _messages . ModelRequest ( parts , instructions = instructions )
214+ return ModelRequestNode [ DepsT , NodeRunEndT ]( request = next_message )
217215
218216 async def _reevaluate_dynamic_prompts (
219217 self , messages : list [_messages .ModelMessage ], run_context : RunContext [DepsT ]
@@ -250,11 +248,6 @@ async def _prepare_request_parameters(
250248 ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]],
251249) -> models .ModelRequestParameters :
252250 """Build tools and create an agent model."""
253- run_context = build_run_context (ctx )
254-
255- # This will raise errors for any tool name conflicts
256- ctx .deps .tool_manager = await ctx .deps .tool_manager .for_run_step (run_context )
257-
258251 output_schema = ctx .deps .output_schema
259252 output_object = None
260253 if isinstance (output_schema , _output .NativeOutputSchema ):
@@ -357,21 +350,21 @@ async def _prepare_request(
357350
358351 run_context = build_run_context (ctx )
359352
360- model_settings = merge_model_settings (ctx .deps .model_settings , None )
353+ # This will raise errors for any tool name conflicts
354+ ctx .deps .tool_manager = await ctx .deps .tool_manager .for_run_step (run_context )
355+
356+ message_history = await _process_message_history (ctx .state , ctx .deps .history_processors , run_context )
361357
362358 model_request_parameters = await _prepare_request_parameters (ctx )
363359 model_request_parameters = ctx .deps .model .customize_request_parameters (model_request_parameters )
364360
365- message_history = await _process_message_history (ctx .state , ctx .deps .history_processors , run_context )
366-
361+ model_settings = ctx .deps .model_settings
367362 usage = ctx .state .usage
368363 if ctx .deps .usage_limits .count_tokens_before_request :
369364 # Copy to avoid modifying the original usage object with the counted usage
370365 usage = dataclasses .replace (usage )
371366
372- counted_usage = await ctx .deps .model .count_tokens (
373- message_history , ctx .deps .model_settings , model_request_parameters
374- )
367+ counted_usage = await ctx .deps .model .count_tokens (message_history , model_settings , model_request_parameters )
375368 usage .incr (counted_usage )
376369
377370 ctx .deps .usage_limits .check_before_request (usage )
0 commit comments