@@ -291,12 +291,12 @@ async def run(
291291 if isinstance (deps , StateHandler ):
292292 deps .state = run_input .state
293293
294- history = _History . from_ag_ui (run_input .messages )
294+ messages = _messages_from_ag_ui (run_input .messages )
295295
296296 async with self .agent .iter (
297297 user_prompt = None ,
298298 output_type = [output_type or self .agent .output_type , DeferredToolCalls ],
299- message_history = history . messages ,
299+ message_history = messages ,
300300 model = model ,
301301 deps = deps ,
302302 model_settings = model_settings ,
@@ -305,7 +305,7 @@ async def run(
305305 infer_name = infer_name ,
306306 toolsets = toolsets ,
307307 ) as run :
308- async for event in self ._agent_stream (run , history ):
308+ async for event in self ._agent_stream (run ):
309309 yield encoder .encode (event )
310310 except _RunError as e :
311311 yield encoder .encode (
@@ -327,20 +327,18 @@ async def run(
327327 async def _agent_stream (
328328 self ,
329329 run : AgentRun [AgentDepsT , Any ],
330- history : _History ,
331330 ) -> AsyncGenerator [BaseEvent , None ]:
332331 """Run the agent streaming responses using AG-UI protocol events.
333332
334333 Args:
335334 run: The agent run to process.
336- history: The history of messages and tool calls to use for the run.
337335
338336 Yields:
339337 AG-UI Server-Sent Events (SSE).
340338 """
341339 async for node in run :
340+ stream_ctx = _RequestStreamContext ()
342341 if isinstance (node , ModelRequestNode ):
343- stream_ctx = _RequestStreamContext ()
344342 async with node .stream (run .ctx ) as request_stream :
345343 async for agent_event in request_stream :
346344 async for msg in self ._handle_model_request_event (stream_ctx , agent_event ):
@@ -352,8 +350,8 @@ async def _agent_stream(
352350 elif isinstance (node , CallToolsNode ):
353351 async with node .stream (run .ctx ) as handle_stream :
354352 async for event in handle_stream :
355- if isinstance (event , FunctionToolResultEvent ) and isinstance ( event . result , ToolReturnPart ) :
356- async for msg in self ._handle_tool_result_event (event . result , history . prompt_message_id ):
353+ if isinstance (event , FunctionToolResultEvent ):
354+ async for msg in self ._handle_tool_result_event (stream_ctx , event ):
357355 yield msg
358356
359357 async def _handle_model_request_event (
@@ -391,9 +389,11 @@ async def _handle_model_request_event(
391389 delta = part .content ,
392390 )
393391 elif isinstance (part , ToolCallPart ): # pragma: no branch
392+ message_id = stream_ctx .message_id or stream_ctx .new_message_id ()
394393 yield ToolCallStartEvent (
395394 tool_call_id = part .tool_call_id ,
396395 tool_call_name = part .tool_name ,
396+ parent_message_id = message_id ,
397397 )
398398 stream_ctx .part_end = ToolCallEndEvent (
399399 tool_call_id = part .tool_call_id ,
@@ -403,11 +403,9 @@ async def _handle_model_request_event(
403403 yield ThinkingTextMessageStartEvent (
404404 type = EventType .THINKING_TEXT_MESSAGE_START ,
405405 )
406- # Always send the content even if it's empty, as it may be
407- # used to indicate the start of thinking.
408406 yield ThinkingTextMessageContentEvent (
409407 type = EventType .THINKING_TEXT_MESSAGE_CONTENT ,
410- delta = part .content or '' ,
408+ delta = part .content ,
411409 )
412410 stream_ctx .part_end = ThinkingTextMessageEndEvent (
413411 type = EventType .THINKING_TEXT_MESSAGE_END ,
@@ -435,20 +433,25 @@ async def _handle_model_request_event(
435433
436434 async def _handle_tool_result_event (
437435 self ,
438- result : ToolReturnPart ,
439- prompt_message_id : str ,
436+ stream_ctx : _RequestStreamContext ,
437+ event : FunctionToolResultEvent ,
440438 ) -> AsyncGenerator [BaseEvent , None ]:
441439 """Convert a tool call result to AG-UI events.
442440
443441 Args:
444- result : The tool call result to process .
445- prompt_message_id : The message ID of the prompt that initiated the tool call .
442+ stream_ctx : The request stream context to manage state .
443+ event : The tool call result event to process .
446444
447445 Yields:
448446 AG-UI Server-Sent Events (SSE).
449447 """
448+ result = event .result
449+ if not isinstance (result , ToolReturnPart ):
450+ return
451+
452+ message_id = stream_ctx .new_message_id ()
450453 yield ToolCallResultEvent (
451- message_id = prompt_message_id ,
454+ message_id = message_id ,
452455 type = EventType .TOOL_CALL_RESULT ,
453456 role = 'tool' ,
454457 tool_call_id = result .tool_call_id ,
@@ -468,75 +471,55 @@ async def _handle_tool_result_event(
468471 yield item
469472
470473
471- @dataclass
472- class _History :
473- """A simple history representation for AG-UI protocol."""
474-
475- prompt_message_id : str # The ID of the last user message.
476- messages : list [ModelMessage ]
477-
478- @classmethod
479- def from_ag_ui (cls , messages : list [Message ]) -> _History :
480- """Convert a AG-UI history to a Pydantic AI one.
481-
482- Args:
483- messages: List of AG-UI messages to convert.
484-
485- Returns:
486- List of Pydantic AI model messages.
487- """
488- prompt_message_id = ''
489- result : list [ModelMessage ] = []
490- tool_calls : dict [str , str ] = {} # Tool call ID to tool name mapping.
491- for msg in messages :
492- if isinstance (msg , UserMessage ):
493- prompt_message_id = msg .id
494- result .append (ModelRequest (parts = [UserPromptPart (content = msg .content )]))
495- elif isinstance (msg , AssistantMessage ):
496- if msg .tool_calls :
497- for tool_call in msg .tool_calls :
498- tool_calls [tool_call .id ] = tool_call .function .name
499-
500- result .append (
501- ModelResponse (
502- parts = [
503- ToolCallPart (
504- tool_name = tool_call .function .name ,
505- tool_call_id = tool_call .id ,
506- args = tool_call .function .arguments ,
507- )
508- for tool_call in msg .tool_calls
509- ]
510- )
511- )
512-
513- if msg .content :
514- result .append (ModelResponse (parts = [TextPart (content = msg .content )]))
515- elif isinstance (msg , SystemMessage ):
516- result .append (ModelRequest (parts = [SystemPromptPart (content = msg .content )]))
517- elif isinstance (msg , ToolMessage ):
518- tool_name = tool_calls .get (msg .tool_call_id )
519- if tool_name is None : # pragma: no cover
520- raise _ToolCallNotFoundError (tool_call_id = msg .tool_call_id )
474+ def _messages_from_ag_ui (messages : list [Message ]) -> list [ModelMessage ]:
475+ """Convert a AG-UI history to a Pydantic AI one."""
476+ result : list [ModelMessage ] = []
477+ tool_calls : dict [str , str ] = {} # Tool call ID to tool name mapping.
478+ for msg in messages :
479+ if isinstance (msg , UserMessage ):
480+ result .append (ModelRequest (parts = [UserPromptPart (content = msg .content )]))
481+ elif isinstance (msg , AssistantMessage ):
482+ if msg .tool_calls :
483+ for tool_call in msg .tool_calls :
484+ tool_calls [tool_call .id ] = tool_call .function .name
521485
522486 result .append (
523- ModelRequest (
487+ ModelResponse (
524488 parts = [
525- ToolReturnPart (
526- tool_name = tool_name ,
527- content = msg . content ,
528- tool_call_id = msg . tool_call_id ,
489+ ToolCallPart (
490+ tool_name = tool_call . function . name ,
491+ tool_call_id = tool_call . id ,
492+ args = tool_call . function . arguments ,
529493 )
494+ for tool_call in msg .tool_calls
530495 ]
531496 )
532497 )
533- elif isinstance (msg , DeveloperMessage ): # pragma: no branch
534- result .append (ModelRequest (parts = [SystemPromptPart (content = msg .content )]))
535498
536- return cls (
537- prompt_message_id = prompt_message_id ,
538- messages = result ,
539- )
499+ if msg .content :
500+ result .append (ModelResponse (parts = [TextPart (content = msg .content )]))
501+ elif isinstance (msg , SystemMessage ):
502+ result .append (ModelRequest (parts = [SystemPromptPart (content = msg .content )]))
503+ elif isinstance (msg , ToolMessage ):
504+ tool_name = tool_calls .get (msg .tool_call_id )
505+ if tool_name is None : # pragma: no cover
506+ raise _ToolCallNotFoundError (tool_call_id = msg .tool_call_id )
507+
508+ result .append (
509+ ModelRequest (
510+ parts = [
511+ ToolReturnPart (
512+ tool_name = tool_name ,
513+ content = msg .content ,
514+ tool_call_id = msg .tool_call_id ,
515+ )
516+ ]
517+ )
518+ )
519+ elif isinstance (msg , DeveloperMessage ): # pragma: no branch
520+ result .append (ModelRequest (parts = [SystemPromptPart (content = msg .content )]))
521+
522+ return result
540523
541524
542525@runtime_checkable
0 commit comments