diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index 9a68bc6775..9f6348c6c9 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -36,7 +36,7 @@
 from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
 from .output import OutputDataT, OutputSpec
 from .profiles import ModelProfile
-from .result import FinalResult, StreamedRunResult
+from .result import AgentStream, FinalResult, StreamedRunResult
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
     AgentDepsT,
@@ -1127,29 +1127,15 @@ async def main():
             while True:
                 if self.is_model_request_node(node):
                     graph_ctx = agent_run.ctx
-                    async with node._stream(graph_ctx) as streamed_response:  # pyright: ignore[reportPrivateUsage]
-
-                        async def stream_to_final(
-                            s: models.StreamedResponse,
-                        ) -> FinalResult[models.StreamedResponse] | None:
-                            output_schema = graph_ctx.deps.output_schema
-                            async for maybe_part_event in streamed_response:
-                                if isinstance(maybe_part_event, _messages.PartStartEvent):
-                                    new_part = maybe_part_event.part
-                                    if isinstance(new_part, _messages.TextPart) and isinstance(
-                                        output_schema, _output.TextOutputSchema
-                                    ):
-                                        return FinalResult(s, None, None)
-                                    elif isinstance(new_part, _messages.ToolCallPart) and (
-                                        tool_def := graph_ctx.deps.tool_manager.get_tool_def(new_part.tool_name)
-                                    ):
-                                        if tool_def.kind == 'output':
-                                            return FinalResult(s, new_part.tool_name, new_part.tool_call_id)
-                                        elif tool_def.kind == 'deferred':
-                                            return FinalResult(s, None, None)
+                    async with node.stream(graph_ctx) as stream:
+
+                        async def stream_to_final(s: AgentStream) -> FinalResult[AgentStream] | None:
+                            async for event in stream:
+                                if isinstance(event, _messages.FinalResultEvent):
+                                    return FinalResult(s, event.tool_name, event.tool_call_id)
                             return None
 
-                        final_result = await stream_to_final(streamed_response)
+                        final_result = await stream_to_final(stream)
                         if final_result is not None:
                             if yielded:
                                 raise exceptions.AgentRunError('Agent run produced final results')  # pragma: no cover
@@ -1184,14 +1170,8 @@ async def on_complete() -> None:
                             yield StreamedRunResult(
                                 messages,
                                 graph_ctx.deps.new_message_index,
-                                graph_ctx.deps.usage_limits,
-                                streamed_response,
-                                graph_ctx.deps.output_schema,
-                                _agent_graph.build_run_context(graph_ctx),
-                                graph_ctx.deps.output_validators,
-                                final_result.tool_name,
+                                stream,
                                 on_complete,
-                                graph_ctx.deps.tool_manager,
                             )
                             break
                 next_node = await agent_run.next(node)
diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py
index d8439fb5d7..e640302b24 100644
--- a/pydantic_ai_slim/pydantic_ai/result.py
+++ b/pydantic_ai_slim/pydantic_ai/result.py
@@ -63,22 +63,18 @@ async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterat
         async for response in self.stream_responses(debounce_by=debounce_by):
             if self._final_result_event is not None:
                 try:
-                    yield await self._validate_response(
-                        response, self._final_result_event.tool_name, allow_partial=True
-                    )
+                    yield await self._validate_response(response, allow_partial=True)
                 except ValidationError:
                     pass
 
         if self._final_result_event is not None:  # pragma: no branch
-            yield await self._validate_response(
-                self._raw_stream_response.get(), self._final_result_event.tool_name, allow_partial=False
-            )
+            yield await self._validate_response(self._raw_stream_response.get(), allow_partial=False)
 
     async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
        """Asynchronously stream the (unvalidated) model responses for the agent."""
         # if the message currently has any parts with content, yield before streaming
         msg = self._raw_stream_response.get()
         for part in msg.parts:
-            if part.has_content():  # pragma: no cover
+            if part.has_content():
                 yield msg
                 break
 
@@ -86,6 +82,35 @@ async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIte
             async for _items in group_iter:
                 yield self._raw_stream_response.get()  # current state of the response
 
+    async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
+        """Stream the text result as an async iterable.
+
+        !!! note
+            Result validators will NOT be called on the text result if `delta=True`.
+
+        Args:
+            delta: if `True`, yield each chunk of text as it is received, if `False` (default), yield the full text
+                up to the current point.
+            debounce_by: by how much (if at all) to debounce/group the response chunks by. `None` means no debouncing.
+                Debouncing is particularly important for long structured responses to reduce the overhead of
+                performing validation as each token is received.
+        """
+        if not isinstance(self._output_schema, PlainTextOutputSchema):
+            raise exceptions.UserError('stream_text() can only be used with text responses')
+
+        if delta:
+            async for text in self._stream_response_text(delta=True, debounce_by=debounce_by):
+                yield text
+        else:
+            async for text in self._stream_response_text(delta=False, debounce_by=debounce_by):
+                for validator in self._output_validators:
+                    text = await validator.validate(text, self._run_ctx)  # pragma: no cover
+                yield text
+
+    def get(self) -> _messages.ModelResponse:
+        """Get the current state of the response."""
+        return self._raw_stream_response.get()
+
     def usage(self) -> Usage:
         """Return the usage of the whole run.
 
@@ -94,10 +119,24 @@ def usage(self) -> Usage:
         """
         return self._initial_run_ctx_usage + self._raw_stream_response.usage()
 
-    async def _validate_response(
-        self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False
-    ) -> OutputDataT:
+    def timestamp(self) -> datetime:
+        """Get the timestamp of the response."""
+        return self._raw_stream_response.timestamp
+
+    async def get_output(self) -> OutputDataT:
+        """Stream the whole response, validate the output and return it."""
+        async for _ in self:
+            pass
+
+        return await self._validate_response(self._raw_stream_response.get(), allow_partial=False)
+
+    async def _validate_response(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT:
         """Validate a structured result message."""
+        if self._final_result_event is None:
+            raise exceptions.UnexpectedModelBehavior('Invalid response, unable to find output')  # pragma: no cover
+
+        output_tool_name = self._final_result_event.tool_name
+
         if isinstance(self._output_schema, ToolOutputSchema) and output_tool_name is not None:
             tool_call = next(
                 (
@@ -114,7 +153,7 @@ async def _validate_response(
             return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial)
         elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts):
             if not self._output_schema.allows_deferred_tool_calls:
-                raise exceptions.UserError(  # pragma: no cover
+                raise exceptions.UserError(
                     'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
                )
             return cast(OutputDataT, deferred_tool_calls)
@@ -132,6 +171,54 @@
                 'Invalid response, unable to process text output'
             )
 
+    async def _stream_response_text(
+        self, *, delta: bool = False, debounce_by: float | None = 0.1
+    ) -> AsyncIterator[str]:
+        """Stream the response as an async iterable of text."""
+
+        # Define a "merged" version of the iterator that will yield items that have already been retrieved
+        # and items that we receive while streaming. We define a dedicated async iterator for this so we can
+        # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below.
+        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
+            # yields tuples of (text_content, part_index)
+            # we don't currently make use of the part_index, but in principle this may be useful
+            # so we retain it here for now to make possible future refactors simpler
+            msg = self._raw_stream_response.get()
+            for i, part in enumerate(msg.parts):
+                if isinstance(part, _messages.TextPart) and part.content:
+                    yield part.content, i
+
+            async for event in self._raw_stream_response:
+                if (
+                    isinstance(event, _messages.PartStartEvent)
+                    and isinstance(event.part, _messages.TextPart)
+                    and event.part.content
+                ):
+                    yield event.part.content, event.index  # pragma: no cover
+                elif (  # pragma: no branch
+                    isinstance(event, _messages.PartDeltaEvent)
+                    and isinstance(event.delta, _messages.TextPartDelta)
+                    and event.delta.content_delta
+                ):
+                    yield event.delta.content_delta, event.index
+
+        async def _stream_text_deltas() -> AsyncIterator[str]:
+            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
+                async for items in group_iter:
+                    # Note: we are currently just dropping the part index on the group here
+                    yield ''.join([content for content, _ in items])
+
+        if delta:
+            async for text in _stream_text_deltas():
+                yield text
+        else:
+            # a quick benchmark shows it's faster to build up a string with concat when we're
+            # yielding at each step
+            deltas: list[str] = []
+            async for text in _stream_text_deltas():
+                deltas.append(text)
+            yield ''.join(deltas)
+
     def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
         """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
 
@@ -189,16 +276,9 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
     _all_messages: list[_messages.ModelMessage]
     _new_message_index: int
 
-    _usage_limits: UsageLimits | None
-    _stream_response: models.StreamedResponse
-    _output_schema: OutputSchema[OutputDataT]
-    _run_ctx: RunContext[AgentDepsT]
-    _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
-    _output_tool_name: str | None
+    _stream_response: AgentStream[AgentDepsT, OutputDataT]
     _on_complete: Callable[[], Awaitable[None]]
-    _tool_manager: ToolManager[AgentDepsT]
 
-    _initial_run_ctx_usage: Usage = field(init=False)
     is_complete: bool = field(default=False, init=False)
     """Whether the stream has all been received.
 
@@ -209,9 +289,6 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]):
     [`get_output`][pydantic_ai.result.StreamedRunResult.get_output] completes.
     """
 
-    def __post_init__(self):
-        self._initial_run_ctx_usage = copy(self._run_ctx.usage)
-
     @overload
     def all_messages(self, *, output_tool_return_content: str | None = None) -> list[_messages.ModelMessage]: ...
 
@@ -332,12 +409,9 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[Outp
         Returns:
             An async iterable of the response data.
""" - async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by): - try: - yield await self.validate_structured_output(structured_message, allow_partial=not is_last) - except ValidationError: - if is_last: - raise # pragma: no cover + async for output in self._stream_response.stream_output(debounce_by=debounce_by): + yield output + await self._marked_completed(self._stream_response.get()) async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]: """Stream the text result as an async iterable. @@ -352,16 +426,8 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. """ - if not isinstance(self._output_schema, PlainTextOutputSchema): - raise exceptions.UserError('stream_text() can only be used with text responses') - - if delta: - async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by): - yield text - else: - async for text in self._stream_response_text(delta=delta, debounce_by=debounce_by): - combined_validated_text = await self._validate_text_output(text) - yield combined_validated_text + async for text in self._stream_response.stream_text(delta=delta, debounce_by=debounce_by): + yield text await self._marked_completed(self._stream_response.get()) async def stream_structured( @@ -378,13 +444,7 @@ async def stream_structured( An async iterable of the structured response message and whether that is the last message. """ # if the message currently has any parts with content, yield before streaming - msg = self._stream_response.get() - for part in msg.parts: - if part.has_content(): - yield msg, False - break - - async for msg in self._stream_response_structured(debounce_by=debounce_by): + async for msg in self._stream_response.stream_responses(debounce_by=debounce_by): yield msg, False msg = self._stream_response.get() @@ -394,15 +454,9 @@ async def stream_structured( async def get_output(self) -> OutputDataT: """Stream the whole response, validate and return it.""" - usage_checking_stream = _get_usage_checking_stream_response( - self._stream_response, self._usage_limits, self.usage - ) - - async for _ in usage_checking_stream: - pass - message = self._stream_response.get() - await self._marked_completed(message) - return await self.validate_structured_output(message) + output = await self._stream_response.get_output() + await self._marked_completed(self._stream_response.get()) + return output @deprecated('`get_data` is deprecated, use `get_output` instead.') async def get_data(self) -> OutputDataT: @@ -414,11 +468,11 @@ def usage(self) -> Usage: !!! note This won't return the full usage until the stream is finished. 
""" - return self._initial_run_ctx_usage + self._stream_response.usage() + return self._stream_response.usage() def timestamp(self) -> datetime: """Get the timestamp of the response.""" - return self._stream_response.timestamp + return self._stream_response.timestamp() @deprecated('`validate_structured_result` is deprecated, use `validate_structured_output` instead.') async def validate_structured_result( @@ -430,105 +484,15 @@ async def validate_structured_output( self, message: _messages.ModelResponse, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" - if isinstance(self._output_schema, ToolOutputSchema) and self._output_tool_name is not None: - tool_call = next( - ( - part - for part in message.parts - if isinstance(part, _messages.ToolCallPart) and part.tool_name == self._output_tool_name - ), - None, - ) - if tool_call is None: - raise exceptions.UnexpectedModelBehavior( # pragma: no cover - f'Invalid response, unable to find tool call for {self._output_tool_name!r}' - ) - return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial) - elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts): - if not self._output_schema.allows_deferred_tool_calls: - raise exceptions.UserError( - 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' - ) - return cast(OutputDataT, deferred_tool_calls) - elif isinstance(self._output_schema, TextOutputSchema): - text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) - - result_data = await self._output_schema.process( - text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False - ) - for validator in self._output_validators: - result_data = await validator.validate(result_data, self._run_ctx) # pragma: no cover - return result_data - else: - raise exceptions.UnexpectedModelBehavior( # pragma: no cover - 'Invalid response, unable to process text output' - ) - - async def _validate_text_output(self, text: str) -> str: - for validator in self._output_validators: - text = await validator.validate(text, self._run_ctx) # pragma: no cover - return text + return await self._stream_response._validate_response( # pyright: ignore[reportPrivateUsage] + message, allow_partial=allow_partial + ) async def _marked_completed(self, message: _messages.ModelResponse) -> None: self.is_complete = True self._all_messages.append(message) await self._on_complete() - async def _stream_response_structured( - self, *, debounce_by: float | None = 0.1 - ) -> AsyncIterator[_messages.ModelResponse]: - async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: - async for _items in group_iter: - yield self._stream_response.get() - - async def _stream_response_text( - self, *, delta: bool = False, debounce_by: float | None = 0.1 - ) -> AsyncIterator[str]: - """Stream the response as an async iterable of text.""" - - # Define a "merged" version of the iterator that will yield items that have already been retrieved - # and items that we receive while streaming. We define a dedicated async iterator for this so we can - # pass the combined stream to the group_by_temporal function within `_stream_text_deltas` below. 
-        async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
-            # yields tuples of (text_content, part_index)
-            # we don't currently make use of the part_index, but in principle this may be useful
-            # so we retain it here for now to make possible future refactors simpler
-            msg = self._stream_response.get()
-            for i, part in enumerate(msg.parts):
-                if isinstance(part, _messages.TextPart) and part.content:
-                    yield part.content, i
-
-            async for event in self._stream_response:
-                if (
-                    isinstance(event, _messages.PartStartEvent)
-                    and isinstance(event.part, _messages.TextPart)
-                    and event.part.content
-                ):
-                    yield event.part.content, event.index  # pragma: no cover
-                elif (  # pragma: no branch
-                    isinstance(event, _messages.PartDeltaEvent)
-                    and isinstance(event.delta, _messages.TextPartDelta)
-                    and event.delta.content_delta
-                ):
-                    yield event.delta.content_delta, event.index
-
-        async def _stream_text_deltas() -> AsyncIterator[str]:
-            async with _utils.group_by_temporal(_stream_text_deltas_ungrouped(), debounce_by) as group_iter:
-                async for items in group_iter:
-                    # Note: we are currently just dropping the part index on the group here
-                    yield ''.join([content for content, _ in items])
-
-        if delta:
-            async for text in _stream_text_deltas():
-                yield text
-        else:
-            # a quick benchmark shows it's faster to build up a string with concat when we're
-            # yielding at each step
-            deltas: list[str] = []
-            async for text in _stream_text_deltas():
-                deltas.append(text)
-            yield ''.join(deltas)
-

 @dataclass(repr=False)
 class FinalResult(Generic[OutputDataT]):
@@ -556,12 +520,12 @@ def _get_usage_checking_stream_response(
 ) -> AsyncIterable[_messages.ModelResponseStreamEvent]:
     if limits is not None and limits.has_token_limits():

-        async def _usage_checking_iterator():  # pragma: no cover
+        async def _usage_checking_iterator():
             async for item in stream_response:
                 limits.check_tokens(get_usage())
                 yield item

-        return _usage_checking_iterator()  # pragma: no cover
+        return _usage_checking_iterator()
     else:
         return stream_response
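For reference, the consumer-facing API is unchanged by this refactor: `Agent.run_stream()` still yields a `StreamedRunResult` exposing `stream_text()`, `stream()`, `get_output()` and `usage()`, only the internals now delegate to the new `AgentStream`. A minimal usage sketch from the caller's side (not part of the diff above; the model string and prompt are illustrative):

import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')  # illustrative model name; any configured model works


async def main() -> None:
    # run_stream() yields a StreamedRunResult, which after this change wraps an AgentStream
    async with agent.run_stream('What is the capital of France?') as result:
        # stream_text() debounces chunks and yields the accumulated text so far by default
        async for text in result.stream_text():
            print(text)
        # once the stream completes, usage() reflects the whole run
        print(result.usage())


if __name__ == '__main__':
    asyncio.run(main())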