diff --git a/docs/durable_execution/temporal.md b/docs/durable_execution/temporal.md index c29e178843..d39f17f055 100644 --- a/docs/durable_execution/temporal.md +++ b/docs/durable_execution/temporal.md @@ -172,7 +172,7 @@ As workflows and activities run in separate processes, any values passed between To account for these limitations, tool functions and the [event stream handler](#streaming) running inside activities receive a limited version of the agent's [`RunContext`][pydantic_ai.tools.RunContext], and it's your responsibility to make sure that the [dependencies](../dependencies.md) object provided to [`TemporalAgent.run()`][pydantic_ai.durable_exec.temporal.TemporalAgent.run] can be serialized using Pydantic. -Specifically, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries` and `run_step` fields are available by default, and trying to access `model`, `usage`, `prompt`, `messages`, or `tracer` will raise an error. +Specifically, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries`, `run_step` and `partial_output` fields are available by default, and trying to access `model`, `usage`, `prompt`, `messages`, or `tracer` will raise an error. If you need one or more of these attributes to be available inside activities, you can create a [`TemporalRunContext`][pydantic_ai.durable_exec.temporal.TemporalRunContext] subclass with custom `serialize_run_context` and `deserialize_run_context` class methods and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent] as `run_context_type`. ### Streaming diff --git a/docs/output.md b/docs/output.md index 85bfbddd1c..182a753944 100644 --- a/docs/output.md +++ b/docs/output.md @@ -470,6 +470,40 @@ print(result.output) _(This example is complete, it can be run "as is")_ +#### Handling partial output in output validators {#partial-output} + +You can use the `partial_output` field on `RunContext` to handle validation differently for partial outputs during streaming (e.g. skip validation altogether). + +```python {title="partial_validation_streaming.py" line_length="120"} +from pydantic_ai import Agent, ModelRetry, RunContext + +agent = Agent('openai:gpt-5') + +@agent.output_validator +def validate_output(ctx: RunContext, output: str) -> str: + if ctx.partial_output: + return output + else: + if len(output) < 50: + raise ModelRetry('Output is too short.') + return output + + +async def main(): + async with agent.run_stream('Write a long story about a cat') as result: + async for message in result.stream_text(): + print(message) + #> Once upon a + #> Once upon a time, there was + #> Once upon a time, there was a curious cat + #> Once upon a time, there was a curious cat named Whiskers who + #> Once upon a time, there was a curious cat named Whiskers who loved to explore + #> Once upon a time, there was a curious cat named Whiskers who loved to explore the world around + #> Once upon a time, there was a curious cat named Whiskers who loved to explore the world around him... +``` + +_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ + ## Image output Some models can generate images as part of their response, for example those that support the [Image Generation built-in tool](builtin-tools.md#image-generation-tool) and OpenAI models using the [Code Execution built-in tool](builtin-tools.md#code-execution-tool) when told to generate a chart. diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index e17afd78a8..1848c42eb1 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -58,6 +58,8 @@ class RunContext(Generic[RunContextAgentDepsT]): """The current step in the run.""" tool_call_approved: bool = False """Whether a tool call that required approval has now been approved.""" + partial_output: bool = False + """Whether the output passed to an output validator is partial.""" @property def last_attempt(self) -> bool: diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py index a5546a4e01..6774d7f8c3 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -147,6 +147,7 @@ async def _call_tool( tool_call_id=call.tool_call_id, retry=self.ctx.retries.get(name, 0), max_retries=tool.max_retries, + partial_output=allow_partial, ) pyd_allow_partial = 'trailing-strings' if allow_partial else 'off' diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py index fa307dd68b..c24587553d 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py @@ -14,7 +14,7 @@ class TemporalRunContext(RunContext[AgentDepsT]): """The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity. - By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries` and `run_step` attributes will be available. + By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries`, `run_step` and `partial_output` attributes will be available. To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent]. """ @@ -49,6 +49,7 @@ def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]: 'retry': ctx.retry, 'max_retries': ctx.max_retries, 'run_step': ctx.run_step, + 'partial_output': ctx.partial_output, } @classmethod diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 20967c299f..a9c92c9ce9 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -2,7 +2,7 @@ from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator from copy import deepcopy -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from datetime import datetime from typing import TYPE_CHECKING, Generic, cast, overload @@ -117,7 +117,7 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = else: async for text in self._stream_response_text(delta=False, debounce_by=debounce_by): for validator in self._output_validators: - text = await validator.validate(text, self._run_ctx) # pragma: no cover + text = await validator.validate(text, replace(self._run_ctx, partial_output=True)) yield text # TODO (v2): Drop in favor of `response` property @@ -195,7 +195,9 @@ async def validate_response_output( text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) for validator in self._output_validators: - result_data = await validator.validate(result_data, self._run_ctx) + result_data = await validator.validate( + result_data, replace(self._run_ctx, partial_output=allow_partial) + ) return result_data else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover diff --git a/tests/test_agent.py b/tests/test_agent.py index a05359b145..8834a427aa 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -3,7 +3,7 @@ import re import sys from collections import defaultdict -from collections.abc import AsyncIterable, Callable +from collections.abc import AsyncIterable, AsyncIterator, Callable from dataclasses import dataclass, replace from datetime import timezone from typing import Any, Generic, Literal, TypeVar, Union @@ -59,7 +59,7 @@ ) from pydantic_ai.agent import AgentRunResult, WrapperAgent from pydantic_ai.builtin_tools import CodeExecutionTool, MCPServerTool, WebSearchTool -from pydantic_ai.models.function import AgentInfo, FunctionModel +from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel from pydantic_ai.output import StructuredDict, ToolOutput from pydantic_ai.result import RunUsage @@ -338,6 +338,85 @@ def validate_output(ctx: RunContext[None], o: Foo) -> Foo: ) +def test_output_validator_partial_sync(): + """Test that output validators receive correct value for `partial_output` in sync mode.""" + call_log: list[tuple[str, bool]] = [] + + agent = Agent[None, str](TestModel(custom_output_text='test output')) + + @agent.output_validator + def validate_output(ctx: RunContext[None], output: str) -> str: + call_log.append((output, ctx.partial_output)) + return output + + result = agent.run_sync('Hello') + assert result.output == 'test output' + + assert call_log == snapshot([('test output', False)]) + + +async def test_output_validator_partial_stream_text(): + """Test that output validators receive correct value for `partial_output` when using stream_text().""" + call_log: list[tuple[str, bool]] = [] + + async def stream_text(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]: + for chunk in ['Hello', ' ', 'world', '!']: + yield chunk + + agent = Agent(FunctionModel(stream_function=stream_text)) + + @agent.output_validator + def validate_output(ctx: RunContext[None], output: str) -> str: + call_log.append((output, ctx.partial_output)) + return output + + async with agent.run_stream('Hello') as result: + text_parts = [] + async for chunk in result.stream_text(debounce_by=None): + text_parts.append(chunk) + + assert text_parts[-1] == 'Hello world!' + assert call_log == snapshot( + [ + ('Hello', True), + ('Hello ', True), + ('Hello world', True), + ('Hello world!', True), + ('Hello world!', False), + ] + ) + + +async def test_output_validator_partial_stream_output(): + """Test that output validators receive correct value for `partial_output` when using stream_output().""" + call_log: list[tuple[Foo, bool]] = [] + + async def stream_model(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaToolCalls]: + assert info.output_tools is not None + yield {0: DeltaToolCall(name=info.output_tools[0].name, json_args='{"a": 42')} + yield {0: DeltaToolCall(json_args=', "b": "f')} + yield {0: DeltaToolCall(json_args='oo"}')} + + agent = Agent(FunctionModel(stream_function=stream_model), output_type=Foo) + + @agent.output_validator + def validate_output(ctx: RunContext[None], output: Foo) -> Foo: + call_log.append((output, ctx.partial_output)) + return output + + async with agent.run_stream('Hello') as result: + outputs = [output async for output in result.stream_output(debounce_by=None)] + + assert outputs[-1] == Foo(a=42, b='foo') + assert call_log == snapshot( + [ + (Foo(a=42, b='f'), True), + (Foo(a=42, b='foo'), True), + (Foo(a=42, b='foo'), False), + ] + ) + + def test_plain_response_then_tuple(): call_index = 0 diff --git a/tests/test_examples.py b/tests/test_examples.py index 504b438cf0..a8d0d33095 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -522,6 +522,7 @@ async def call_tool( 'Where do I live?': 'You live in Mexico City.', 'Tell me about the pydantic/pydantic-ai repo.': 'The pydantic/pydantic-ai repo is a Python agent framework for building Generative AI applications.', 'What do I have on my calendar today?': "You're going to spend all day playing with Pydantic AI.", + 'Write a long story about a cat': 'Once upon a time, there was a curious cat named Whiskers who loved to explore the world around him...', } tool_responses: dict[tuple[str, str], str] = {