Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 9 additions & 29 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
from .output import OutputDataT, OutputSpec
from .profiles import ModelProfile
from .result import FinalResult, StreamedRunResult
from .result import AgentStream, FinalResult, StreamedRunResult
from .settings import ModelSettings, merge_model_settings
from .tools import (
AgentDepsT,
Expand Down Expand Up @@ -1127,29 +1127,15 @@ async def main():
while True:
if self.is_model_request_node(node):
graph_ctx = agent_run.ctx
async with node._stream(graph_ctx) as streamed_response: # pyright: ignore[reportPrivateUsage]

async def stream_to_final(
s: models.StreamedResponse,
) -> FinalResult[models.StreamedResponse] | None:
output_schema = graph_ctx.deps.output_schema
async for maybe_part_event in streamed_response:
if isinstance(maybe_part_event, _messages.PartStartEvent):
new_part = maybe_part_event.part
if isinstance(new_part, _messages.TextPart) and isinstance(
output_schema, _output.TextOutputSchema
):
return FinalResult(s, None, None)
elif isinstance(new_part, _messages.ToolCallPart) and (
tool_def := graph_ctx.deps.tool_manager.get_tool_def(new_part.tool_name)
):
if tool_def.kind == 'output':
return FinalResult(s, new_part.tool_name, new_part.tool_call_id)
elif tool_def.kind == 'deferred':
return FinalResult(s, None, None)
async with node.stream(graph_ctx) as stream:

async def stream_to_final(s: AgentStream) -> FinalResult[AgentStream] | None:
async for event in stream:
if isinstance(event, _messages.FinalResultEvent):
return FinalResult(s, event.tool_name, event.tool_call_id)
return None

final_result = await stream_to_final(streamed_response)
final_result = await stream_to_final(stream)
if final_result is not None:
if yielded:
raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover
Expand Down Expand Up @@ -1184,14 +1170,8 @@ async def on_complete() -> None:
yield StreamedRunResult(
messages,
graph_ctx.deps.new_message_index,
graph_ctx.deps.usage_limits,
streamed_response,
graph_ctx.deps.output_schema,
_agent_graph.build_run_context(graph_ctx),
graph_ctx.deps.output_validators,
final_result.tool_name,
stream,
on_complete,
graph_ctx.deps.tool_manager,
)
break
next_node = await agent_run.next(node)
Expand Down
Loading