|
36 | 36 | from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model |
37 | 37 | from .output import OutputDataT, OutputSpec |
38 | 38 | from .profiles import ModelProfile |
39 | | -from .result import FinalResult, StreamedRunResult |
| 39 | +from .result import AgentStream, FinalResult, StreamedRunResult |
40 | 40 | from .settings import ModelSettings, merge_model_settings |
41 | 41 | from .tools import ( |
42 | 42 | AgentDepsT, |
@@ -1127,29 +1127,15 @@ async def main(): |
1127 | 1127 | while True: |
1128 | 1128 | if self.is_model_request_node(node): |
1129 | 1129 | graph_ctx = agent_run.ctx |
1130 | | - async with node._stream(graph_ctx) as streamed_response: # pyright: ignore[reportPrivateUsage] |
1131 | | - |
1132 | | - async def stream_to_final( |
1133 | | - s: models.StreamedResponse, |
1134 | | - ) -> FinalResult[models.StreamedResponse] | None: |
1135 | | - output_schema = graph_ctx.deps.output_schema |
1136 | | - async for maybe_part_event in streamed_response: |
1137 | | - if isinstance(maybe_part_event, _messages.PartStartEvent): |
1138 | | - new_part = maybe_part_event.part |
1139 | | - if isinstance(new_part, _messages.TextPart) and isinstance( |
1140 | | - output_schema, _output.TextOutputSchema |
1141 | | - ): |
1142 | | - return FinalResult(s, None, None) |
1143 | | - elif isinstance(new_part, _messages.ToolCallPart) and ( |
1144 | | - tool_def := graph_ctx.deps.tool_manager.get_tool_def(new_part.tool_name) |
1145 | | - ): |
1146 | | - if tool_def.kind == 'output': |
1147 | | - return FinalResult(s, new_part.tool_name, new_part.tool_call_id) |
1148 | | - elif tool_def.kind == 'deferred': |
1149 | | - return FinalResult(s, None, None) |
| 1130 | + async with node.stream(graph_ctx) as stream: |
| 1131 | + |
| 1132 | + async def stream_to_final(s: AgentStream) -> FinalResult[AgentStream] | None: |
| 1133 | + async for event in stream: |
| 1134 | + if isinstance(event, _messages.FinalResultEvent): |
| 1135 | + return FinalResult(s, event.tool_name, event.tool_call_id) |
1150 | 1136 | return None |
1151 | 1137 |
|
1152 | | - final_result = await stream_to_final(streamed_response) |
| 1138 | + final_result = await stream_to_final(stream) |
1153 | 1139 | if final_result is not None: |
1154 | 1140 | if yielded: |
1155 | 1141 | raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover |
@@ -1184,14 +1170,8 @@ async def on_complete() -> None: |
1184 | 1170 | yield StreamedRunResult( |
1185 | 1171 | messages, |
1186 | 1172 | graph_ctx.deps.new_message_index, |
1187 | | - graph_ctx.deps.usage_limits, |
1188 | | - streamed_response, |
1189 | | - graph_ctx.deps.output_schema, |
1190 | | - _agent_graph.build_run_context(graph_ctx), |
1191 | | - graph_ctx.deps.output_validators, |
1192 | | - final_result.tool_name, |
| 1173 | + stream, |
1193 | 1174 | on_complete, |
1194 | | - graph_ctx.deps.tool_manager, |
1195 | 1175 | ) |
1196 | 1176 | break |
1197 | 1177 | next_node = await agent_run.next(node) |
|
0 commit comments