Skip to content

Commit 6a153c8

Browse files
petersliDouweM
andauthored
Add partial_output to RunContext (#3286)
Co-authored-by: Douwe Maan <douwe@pydantic.dev>
1 parent d8c0526 commit 6a153c8

File tree

8 files changed

+127
-7
lines changed

8 files changed

+127
-7
lines changed

docs/durable_execution/temporal.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ As workflows and activities run in separate processes, any values passed between
172172

173173
To account for these limitations, tool functions and the [event stream handler](#streaming) running inside activities receive a limited version of the agent's [`RunContext`][pydantic_ai.tools.RunContext], and it's your responsibility to make sure that the [dependencies](../dependencies.md) object provided to [`TemporalAgent.run()`][pydantic_ai.durable_exec.temporal.TemporalAgent.run] can be serialized using Pydantic.
174174

175-
Specifically, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries` and `run_step` fields are available by default, and trying to access `model`, `usage`, `prompt`, `messages`, or `tracer` will raise an error.
175+
Specifically, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries`, `run_step` and `partial_output` fields are available by default, and trying to access `model`, `usage`, `prompt`, `messages`, or `tracer` will raise an error.
176176
If you need one or more of these attributes to be available inside activities, you can create a [`TemporalRunContext`][pydantic_ai.durable_exec.temporal.TemporalRunContext] subclass with custom `serialize_run_context` and `deserialize_run_context` class methods and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent] as `run_context_type`.
177177

178178
### Streaming

docs/output.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,40 @@ print(result.output)
470470

471471
_(This example is complete, it can be run "as is")_
472472

473+
#### Handling partial output in output validators {#partial-output}
474+
475+
You can use the `partial_output` field on `RunContext` to handle validation differently for partial outputs during streaming (e.g. skip validation altogether).
476+
477+
```python {title="partial_validation_streaming.py" line_length="120"}
478+
from pydantic_ai import Agent, ModelRetry, RunContext
479+
480+
agent = Agent('openai:gpt-5')
481+
482+
@agent.output_validator
483+
def validate_output(ctx: RunContext, output: str) -> str:
484+
if ctx.partial_output:
485+
return output
486+
else:
487+
if len(output) < 50:
488+
raise ModelRetry('Output is too short.')
489+
return output
490+
491+
492+
async def main():
493+
async with agent.run_stream('Write a long story about a cat') as result:
494+
async for message in result.stream_text():
495+
print(message)
496+
#> Once upon a
497+
#> Once upon a time, there was
498+
#> Once upon a time, there was a curious cat
499+
#> Once upon a time, there was a curious cat named Whiskers who
500+
#> Once upon a time, there was a curious cat named Whiskers who loved to explore
501+
#> Once upon a time, there was a curious cat named Whiskers who loved to explore the world around
502+
#> Once upon a time, there was a curious cat named Whiskers who loved to explore the world around him...
503+
```
504+
505+
_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_
506+
473507
## Image output
474508

475509
Some models can generate images as part of their response, for example those that support the [Image Generation built-in tool](builtin-tools.md#image-generation-tool) and OpenAI models using the [Code Execution built-in tool](builtin-tools.md#code-execution-tool) when told to generate a chart.

pydantic_ai_slim/pydantic_ai/_run_context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class RunContext(Generic[RunContextAgentDepsT]):
5858
"""The current step in the run."""
5959
tool_call_approved: bool = False
6060
"""Whether a tool call that required approval has now been approved."""
61+
partial_output: bool = False
62+
"""Whether the output passed to an output validator is partial."""
6163

6264
@property
6365
def last_attempt(self) -> bool:

pydantic_ai_slim/pydantic_ai/_tool_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ async def _call_tool(
147147
tool_call_id=call.tool_call_id,
148148
retry=self.ctx.retries.get(name, 0),
149149
max_retries=tool.max_retries,
150+
partial_output=allow_partial,
150151
)
151152

152153
pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_run_context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
class TemporalRunContext(RunContext[AgentDepsT]):
1515
"""The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity.
1616
17-
By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries` and `run_step` attributes will be available.
17+
By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries`, `run_step` and `partial_output` attributes will be available.
1818
To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent].
1919
"""
2020

@@ -49,6 +49,7 @@ def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]:
4949
'retry': ctx.retry,
5050
'max_retries': ctx.max_retries,
5151
'run_step': ctx.run_step,
52+
'partial_output': ctx.partial_output,
5253
}
5354

5455
@classmethod

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator
44
from copy import deepcopy
5-
from dataclasses import dataclass, field
5+
from dataclasses import dataclass, field, replace
66
from datetime import datetime
77
from typing import TYPE_CHECKING, Generic, cast, overload
88

@@ -117,7 +117,7 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None =
117117
else:
118118
async for text in self._stream_response_text(delta=False, debounce_by=debounce_by):
119119
for validator in self._output_validators:
120-
text = await validator.validate(text, self._run_ctx) # pragma: no cover
120+
text = await validator.validate(text, replace(self._run_ctx, partial_output=True))
121121
yield text
122122

123123
# TODO (v2): Drop in favor of `response` property
@@ -195,7 +195,9 @@ async def validate_response_output(
195195
text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
196196
)
197197
for validator in self._output_validators:
198-
result_data = await validator.validate(result_data, self._run_ctx)
198+
result_data = await validator.validate(
199+
result_data, replace(self._run_ctx, partial_output=allow_partial)
200+
)
199201
return result_data
200202
else:
201203
raise exceptions.UnexpectedModelBehavior( # pragma: no cover

tests/test_agent.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import re
44
import sys
55
from collections import defaultdict
6-
from collections.abc import AsyncIterable, Callable
6+
from collections.abc import AsyncIterable, AsyncIterator, Callable
77
from dataclasses import dataclass, replace
88
from datetime import timezone
99
from typing import Any, Generic, Literal, TypeVar, Union
@@ -59,7 +59,7 @@
5959
)
6060
from pydantic_ai.agent import AgentRunResult, WrapperAgent
6161
from pydantic_ai.builtin_tools import CodeExecutionTool, MCPServerTool, WebSearchTool
62-
from pydantic_ai.models.function import AgentInfo, FunctionModel
62+
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
6363
from pydantic_ai.models.test import TestModel
6464
from pydantic_ai.output import StructuredDict, ToolOutput
6565
from pydantic_ai.result import RunUsage
@@ -338,6 +338,85 @@ def validate_output(ctx: RunContext[None], o: Foo) -> Foo:
338338
)
339339

340340

341+
def test_output_validator_partial_sync():
342+
"""Test that output validators receive correct value for `partial_output` in sync mode."""
343+
call_log: list[tuple[str, bool]] = []
344+
345+
agent = Agent[None, str](TestModel(custom_output_text='test output'))
346+
347+
@agent.output_validator
348+
def validate_output(ctx: RunContext[None], output: str) -> str:
349+
call_log.append((output, ctx.partial_output))
350+
return output
351+
352+
result = agent.run_sync('Hello')
353+
assert result.output == 'test output'
354+
355+
assert call_log == snapshot([('test output', False)])
356+
357+
358+
async def test_output_validator_partial_stream_text():
359+
"""Test that output validators receive correct value for `partial_output` when using stream_text()."""
360+
call_log: list[tuple[str, bool]] = []
361+
362+
async def stream_text(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]:
363+
for chunk in ['Hello', ' ', 'world', '!']:
364+
yield chunk
365+
366+
agent = Agent(FunctionModel(stream_function=stream_text))
367+
368+
@agent.output_validator
369+
def validate_output(ctx: RunContext[None], output: str) -> str:
370+
call_log.append((output, ctx.partial_output))
371+
return output
372+
373+
async with agent.run_stream('Hello') as result:
374+
text_parts = []
375+
async for chunk in result.stream_text(debounce_by=None):
376+
text_parts.append(chunk)
377+
378+
assert text_parts[-1] == 'Hello world!'
379+
assert call_log == snapshot(
380+
[
381+
('Hello', True),
382+
('Hello ', True),
383+
('Hello world', True),
384+
('Hello world!', True),
385+
('Hello world!', False),
386+
]
387+
)
388+
389+
390+
async def test_output_validator_partial_stream_output():
391+
"""Test that output validators receive correct value for `partial_output` when using stream_output()."""
392+
call_log: list[tuple[Foo, bool]] = []
393+
394+
async def stream_model(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
395+
assert info.output_tools is not None
396+
yield {0: DeltaToolCall(name=info.output_tools[0].name, json_args='{"a": 42')}
397+
yield {0: DeltaToolCall(json_args=', "b": "f')}
398+
yield {0: DeltaToolCall(json_args='oo"}')}
399+
400+
agent = Agent(FunctionModel(stream_function=stream_model), output_type=Foo)
401+
402+
@agent.output_validator
403+
def validate_output(ctx: RunContext[None], output: Foo) -> Foo:
404+
call_log.append((output, ctx.partial_output))
405+
return output
406+
407+
async with agent.run_stream('Hello') as result:
408+
outputs = [output async for output in result.stream_output(debounce_by=None)]
409+
410+
assert outputs[-1] == Foo(a=42, b='foo')
411+
assert call_log == snapshot(
412+
[
413+
(Foo(a=42, b='f'), True),
414+
(Foo(a=42, b='foo'), True),
415+
(Foo(a=42, b='foo'), False),
416+
]
417+
)
418+
419+
341420
def test_plain_response_then_tuple():
342421
call_index = 0
343422

tests/test_examples.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,7 @@ async def call_tool(
522522
'Where do I live?': 'You live in Mexico City.',
523523
'Tell me about the pydantic/pydantic-ai repo.': 'The pydantic/pydantic-ai repo is a Python agent framework for building Generative AI applications.',
524524
'What do I have on my calendar today?': "You're going to spend all day playing with Pydantic AI.",
525+
'Write a long story about a cat': 'Once upon a time, there was a curious cat named Whiskers who loved to explore the world around him...',
525526
}
526527

527528
tool_responses: dict[tuple[str, str], str] = {

0 commit comments

Comments
 (0)