|
3 | 3 | import re |
4 | 4 | import sys |
5 | 5 | from collections import defaultdict |
6 | | -from collections.abc import AsyncIterable, Callable |
| 6 | +from collections.abc import AsyncIterable, AsyncIterator, Callable |
7 | 7 | from dataclasses import dataclass, replace |
8 | 8 | from datetime import timezone |
9 | 9 | from typing import Any, Generic, Literal, TypeVar, Union |
|
59 | 59 | ) |
60 | 60 | from pydantic_ai.agent import AgentRunResult, WrapperAgent |
61 | 61 | from pydantic_ai.builtin_tools import CodeExecutionTool, MCPServerTool, WebSearchTool |
62 | | -from pydantic_ai.models.function import AgentInfo, FunctionModel |
| 62 | +from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel |
63 | 63 | from pydantic_ai.models.test import TestModel |
64 | 64 | from pydantic_ai.output import StructuredDict, ToolOutput |
65 | 65 | from pydantic_ai.result import RunUsage |
@@ -338,6 +338,85 @@ def validate_output(ctx: RunContext[None], o: Foo) -> Foo: |
338 | 338 | ) |
339 | 339 |
|
340 | 340 |
|
| 341 | +def test_output_validator_partial_sync(): |
| 342 | + """Test that output validators receive correct value for `partial_output` in sync mode.""" |
| 343 | + call_log: list[tuple[str, bool]] = [] |
| 344 | + |
| 345 | + agent = Agent[None, str](TestModel(custom_output_text='test output')) |
| 346 | + |
| 347 | + @agent.output_validator |
| 348 | + def validate_output(ctx: RunContext[None], output: str) -> str: |
| 349 | + call_log.append((output, ctx.partial_output)) |
| 350 | + return output |
| 351 | + |
| 352 | + result = agent.run_sync('Hello') |
| 353 | + assert result.output == 'test output' |
| 354 | + |
| 355 | + assert call_log == snapshot([('test output', False)]) |
| 356 | + |
| 357 | + |
| 358 | +async def test_output_validator_partial_stream_text(): |
| 359 | + """Test that output validators receive correct value for `partial_output` when using stream_text().""" |
| 360 | + call_log: list[tuple[str, bool]] = [] |
| 361 | + |
| 362 | + async def stream_text(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str]: |
| 363 | + for chunk in ['Hello', ' ', 'world', '!']: |
| 364 | + yield chunk |
| 365 | + |
| 366 | + agent = Agent(FunctionModel(stream_function=stream_text)) |
| 367 | + |
| 368 | + @agent.output_validator |
| 369 | + def validate_output(ctx: RunContext[None], output: str) -> str: |
| 370 | + call_log.append((output, ctx.partial_output)) |
| 371 | + return output |
| 372 | + |
| 373 | + async with agent.run_stream('Hello') as result: |
| 374 | + text_parts = [] |
| 375 | + async for chunk in result.stream_text(debounce_by=None): |
| 376 | + text_parts.append(chunk) |
| 377 | + |
| 378 | + assert text_parts[-1] == 'Hello world!' |
| 379 | + assert call_log == snapshot( |
| 380 | + [ |
| 381 | + ('Hello', True), |
| 382 | + ('Hello ', True), |
| 383 | + ('Hello world', True), |
| 384 | + ('Hello world!', True), |
| 385 | + ('Hello world!', False), |
| 386 | + ] |
| 387 | + ) |
| 388 | + |
| 389 | + |
| 390 | +async def test_output_validator_partial_stream_output(): |
| 391 | + """Test that output validators receive correct value for `partial_output` when using stream_output().""" |
| 392 | + call_log: list[tuple[Foo, bool]] = [] |
| 393 | + |
| 394 | + async def stream_model(messages: list[ModelMessage], info: AgentInfo) -> AsyncIterator[DeltaToolCalls]: |
| 395 | + assert info.output_tools is not None |
| 396 | + yield {0: DeltaToolCall(name=info.output_tools[0].name, json_args='{"a": 42')} |
| 397 | + yield {0: DeltaToolCall(json_args=', "b": "f')} |
| 398 | + yield {0: DeltaToolCall(json_args='oo"}')} |
| 399 | + |
| 400 | + agent = Agent(FunctionModel(stream_function=stream_model), output_type=Foo) |
| 401 | + |
| 402 | + @agent.output_validator |
| 403 | + def validate_output(ctx: RunContext[None], output: Foo) -> Foo: |
| 404 | + call_log.append((output, ctx.partial_output)) |
| 405 | + return output |
| 406 | + |
| 407 | + async with agent.run_stream('Hello') as result: |
| 408 | + outputs = [output async for output in result.stream_output(debounce_by=None)] |
| 409 | + |
| 410 | + assert outputs[-1] == Foo(a=42, b='foo') |
| 411 | + assert call_log == snapshot( |
| 412 | + [ |
| 413 | + (Foo(a=42, b='f'), True), |
| 414 | + (Foo(a=42, b='foo'), True), |
| 415 | + (Foo(a=42, b='foo'), False), |
| 416 | + ] |
| 417 | + ) |
| 418 | + |
| 419 | + |
341 | 420 | def test_plain_response_then_tuple(): |
342 | 421 | call_index = 0 |
343 | 422 |
|
|
0 commit comments