From 9c8ec45d0a92d05d0c694399055800e3b7a33a8d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 28 Nov 2024 15:44:21 +0000 Subject: [PATCH 1/4] adding docs for testing and evals --- docs/api/agent.md | 3 +- docs/api/models/test.md | 2 + docs/dependencies.md | 4 +- docs/testing-evals.md | 186 +++++++++++++++++++- mkdocs.yml | 1 + pydantic_ai_slim/pydantic_ai/_utils.py | 14 +- pydantic_ai_slim/pydantic_ai/agent.py | 53 +++--- pydantic_ai_slim/pydantic_ai/models/test.py | 14 +- tests/test_agent.py | 6 +- tests/test_deps.py | 4 +- tests/test_examples.py | 37 +++- tests/typed_agent.py | 4 +- 12 files changed, 278 insertions(+), 50 deletions(-) diff --git a/docs/api/agent.md b/docs/api/agent.md index 6b012f1e..085abe6c 100644 --- a/docs/api/agent.md +++ b/docs/api/agent.md @@ -8,8 +8,7 @@ - run_sync - run_stream - model - - override_deps - - override_model + - override - last_run_messages - system_prompt - tool diff --git a/docs/api/models/test.md b/docs/api/models/test.md index 35ffc19d..79fbad1a 100644 --- a/docs/api/models/test.md +++ b/docs/api/models/test.md @@ -1,3 +1,5 @@ # `pydantic_ai.models.test` +Utility model for quickly testing apps built with PydanticAI. + ::: pydantic_ai.models.test diff --git a/docs/dependencies.md b/docs/dependencies.md index 81c067fc..70042124 100644 --- a/docs/dependencies.md +++ b/docs/dependencies.md @@ -231,7 +231,7 @@ When testing agents, it's useful to be able to customise dependencies. While this can sometimes be done by calling the agent directly within unit tests, we can also override dependencies while calling application code which in turn calls the agent. -This is done via the [`override_deps`][pydantic_ai.Agent.override_deps] method on the agent. +This is done via the [`override`][pydantic_ai.Agent.override] method on the agent. ```py title="joke_app.py" from dataclasses import dataclass @@ -286,7 +286,7 @@ class TestMyDeps(MyDeps): # (1)! async def test_application_code(): test_deps = TestMyDeps('test_key', None) # (2)! - with joke_agent.override_deps(test_deps): # (3)! + with joke_agent.override(deps=test_deps): # (3)! joke = await application_code('Tell me a joke.') # (4)! assert joke.startswith('Did you hear about the toothpaste scandal?') ``` diff --git a/docs/testing-evals.md b/docs/testing-evals.md index a6165abf..2c760b50 100644 --- a/docs/testing-evals.md +++ b/docs/testing-evals.md @@ -1,8 +1,188 @@ +from black import timezonefrom black import timezonefrom black import timezone + # Testing and Evals +When thinking about PydanticAI use and LLM integrations in general, there are two distinct kinds of test: + +1. **Unit tests** — tests of your application code, and whether it's behaving correctly +2. **"Evals"** — tests of the LLM, and how good or bad its responses are + +For the most part, these two kinds of tests have pretty separate goals and considerations. + +## Unit tests + +Unit tests for PydanticAI code are just like unit tests for any other Python code. + +Because for the most part they're nothing new, we have pretty well established tools and patterns for writing and running these kinds of tests. 
+ +Unless you're really sure you know better, you'll probably want to follow roughly this strategy: + +* Use [`pytest`](https://docs.pytest.org/en/stable/) as your test harness +* If you find yourself typing out long assertions, use [`inline-snapshot`](https://15r10nk.github.io/inline-snapshot/latest/) +* Similarly, [dirty-equals](https://dirty-equals.helpmanual.io/latest/) can be useful for comparing large data structures +* Use [`TestModel`][pydantic_ai.models.test.TestModel] or [`FunctionModel`][pydantic_ai.models.function.FunctionModel] in place of your actual model to avoid the cost, latency and variability of real LLM calls +* Use [`Agent.override`][pydantic_ai.agent.Agent.override] to replace your model inside your application logic. +* Set [`ALLOW_MODEL_REQUESTS=False`][pydantic_ai.models.ALLOW_MODEL_REQUESTS] globally to block any requests from being made to non-test models + +### Unit testing with `TestModel` + +The simplest and fastest way to exercise most of your application code is using [`TestModel`][pydantic_ai.models.test.TestModel], this will (by default) call all tools in the agent, then return either plain text or a structured response depending on the return type of the agent. + +!!! note "`TestModel` is not magic" + The "clever" (but not too clever) part of `TestModel` is that it will attempt to generate valid structured data for [function tools](agents.md#function-tools) and [result types](results.md#structured-result-validation) based on the schema of the registered tools. + + There's no ML or AI in `TestModel`, it's just plain old procedural Python code that tries to generate data that satisfies the JSON schema of a tool. + + The resulting data won't look pretty or relevant, but it should pass Pydantic's validation in most cases. + If you want something more sophisticated, use [`FunctionModel`][pydantic_ai.models.function.FunctionModel] and write your own data generation logic. + +Let's consider the following application code: + +```py title="weather_app.py" +import asyncio +from datetime import date + +from pydantic_ai import Agent, CallContext + +from fake_database import DatabaseConn # (1)! +from weather_service import WeatherService # (2)! + +weather_agent = Agent( + 'openai:gpt-4o', + deps_type=WeatherService, + system_prompt='Providing a weather forecast at the locations the user provides.', +) + + +@weather_agent.tool +def weather_forecast( + ctx: CallContext[WeatherService], location: str, forecast_date: date +) -> str: + if forecast_date < date.today(): + # (pretend) use the cheaper endpoint to get historical data + return ctx.deps.get_historic_weather(location, forecast_date) + else: + return ctx.deps.get_forecast(location, forecast_date) + + +async def run_weather_forecast( # (3)! + user_prompts: list[tuple[str, int]], conn: DatabaseConn +): + """Run weather forecast for a list of user prompts. + + Args: + user_prompts: A list of tuples containing the user prompt and user id. + conn: A database connection to store the forecast results. + """ + async with WeatherService() as weather_service: + + async def run_forecast(prompt: str, user_id: int): + result = await weather_agent.run(prompt, deps=weather_service) + await conn.store_forecast(user_id, result.data) + + # run all prompts in parallel + await asyncio.gather( + *(run_forecast(prompt, user_id) for (prompt, user_id) in user_prompts) + ) +``` + +1. `DatabaseConn` is a class that holds a database connection +2. `WeatherService` is a class that provides weather data +3. 
This function is the code we want to test, together with the agent it uses + +Here we have a function that takes a list of `#!python (user_prompt, user_id)` tuples, gets a weather forecast for each prompt, and stores the result in the database. + +We want to test this code without having to mock certain objects or modify our code so we can pass test objects in. + +Here's how we would write tests using [`TestModel`][pydantic_ai.models.test.TestModel]: + +```py title="test_weather_app.py" +from datetime import timezone +import pytest + +from dirty_equals import IsNow + +from pydantic_ai import models +from pydantic_ai.models.test import TestModel +from pydantic_ai.messages import ( + SystemPrompt, + UserPrompt, + ModelStructuredResponse, + ToolCall, + ArgsObject, + ToolReturn, + ModelTextResponse, +) + +from fake_database import DatabaseConn +from weather_app import run_weather_forecast, weather_agent + +pytestmark = pytest.mark.anyio # (1)! +models.ALLOW_MODEL_REQUESTS = False # (2)! + + +async def test_forecast_success(): + conn = DatabaseConn() + user_id = 1 + with weather_agent.override(model=TestModel()): # (3)! + prompt = 'What will the weather be like in London on 2024-11-28?' + await run_weather_forecast([(prompt, user_id)], conn) # (4)! + + forecast = await conn.get_forecast(user_id) + assert forecast == '{"weather_forecast":"Sunny with a chance of rain"}' # (5)! + + assert weather_agent.last_run_messages == [ # (6)! + SystemPrompt( + content='Providing a weather forecast at the locations the user provides.', + role='system', + ), + UserPrompt( + content='What will the weather be like in London on 2024-11-28?', + timestamp=IsNow(tz=timezone.utc), # (7)! + role='user', + ), + ModelStructuredResponse( + calls=[ + ToolCall( + tool_name='weather_forecast', + args=ArgsObject( + args_object={'location': 'a', 'forecast_date': '2024-01-01'} + ), + tool_id=None, + ) + ], + timestamp=IsNow(tz=timezone.utc), + role='model-structured-response', + ), + ToolReturn( + tool_name='weather_forecast', + content='Sunny with a chance of rain', + tool_id=None, + timestamp=IsNow(tz=timezone.utc), + role='tool-return', + ), + ModelTextResponse( + content='{"weather_forecast":"Sunny with a chance of rain"}', + timestamp=IsNow(tz=timezone.utc), + role='model-text-response', + ), + ] +``` + +1. We're using [anyio](https://anyio.readthedocs.io/en/stable/) to run async tests. +2. This is a safety measure to make sure we don't accidentally make real requests to the LLM while testing. +3. We're using [`override`][pydantic_ai.agent.Agent.override] to replace the agent's model with [`TestModel`][pydantic_ai.models.test.TestModel]. +4. Now we call the function we want to test. +5. But default, `TestModel` will return a JSON string summarising the tools calls made, and what was returned. If you wanted to customise the response to something more closely aligned with the domain, you could add [`custom_result_text='Sunny'`][pydantic_ai.models.test.TestModel.custom_result_text] when defining `TestModel`. +6. So far we don't actually know which tools were called and with which values, we can use the [`last_run_messages`][pydantic_ai.agent.Agent.last_run_messages] attribute to inspect messages from the most recent run and assert the exchange between the agent and the model occurred as expected. +7. The [`IsNow`][dirty_equals.IsNow] helper allows us to use declarative asserts even with data which will contain timestamps that change over time. + +### Unit testing with `FunctionModel` + TODO -principles: +## Evals + +TODO. 
-* unit tests are no different to any other app, just `TestModel` or `FunctionModel`, we know how to do unit tests, there's no magic just good practice -* evals are more like benchmarks, they never "pass" although they do "fail", you care mostly about how they change over time, we (and we think most other people) don't really know what a "good" eval is, we provide some useful tools, we'll improve this if/when a common best practice emerges, or we think we have something interesting to say +evals are more like benchmarks, they never "pass" although they do "fail", you care mostly about how they change over time, we (and we think most other people) don't really know what a "good" eval is, we provide some useful tools, we'll improve this if/when a common best practice emerges, or we think we have something interesting to say diff --git a/mkdocs.yml b/mkdocs.yml index c4ebc14a..589b1d26 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -155,6 +155,7 @@ plugins: import: - url: https://docs.python.org/3/objects.inv - url: https://docs.pydantic.dev/latest/objects.inv + - url: https://dirty-equals.helpmanual.io/latest/objects.inv - url: https://fastapi.tiangolo.com/objects.inv - url: https://typing-extensions.readthedocs.io/en/latest/objects.inv - url: https://rich.readthedocs.io/en/stable/objects.inv diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index d81dafe6..daa59874 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -8,7 +8,7 @@ from datetime import datetime, timezone from functools import partial from types import GenericAlias -from typing import Any, Callable, Generic, TypeVar, Union, cast, overload +from typing import Any, Callable, Generic, TypeGuard, TypeVar, Union, cast, overload from pydantic import BaseModel from pydantic.json_schema import JsonSchemaValue @@ -66,10 +66,6 @@ class Some(Generic[T]): """Analogous to Rust's `Option` type, usage: `Option[Thing]` is equivalent to `Some[Thing] | None`.""" -Left = TypeVar('Left') -Right = TypeVar('Right') - - class Unset: """A singleton to represent an unset value.""" @@ -79,6 +75,14 @@ class Unset: UNSET = Unset() +def is_set(t_or_unset: T | Unset) -> TypeGuard[T]: + return t_or_unset is not UNSET + + +Left = TypeVar('Left') +Right = TypeVar('Right') + + class Either(Generic[Left, Right]): """Two member Union that records which member was set, this is analogous to Rust enums with two variants. diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index fbb35bcf..1e4d043a 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -105,7 +105,7 @@ def __init__( it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, which checks for the necessary environment variables. Set this to `false` to defer the evaluation until the first run. Useful if you want to - [override the model][pydantic_ai.Agent.override_model] for testing. + [override the model][pydantic_ai.Agent.override] for testing. """ if model is None or defer_model_check: self.model = model @@ -296,32 +296,41 @@ async def run_stream( cost += model_response.cost() @contextmanager - def override_deps(self, overriding_deps: AgentDeps) -> Iterator[None]: - """Context manager to temporarily override agent dependencies, this is particularly useful when testing. 
+ def override( + self, + *, + deps: AgentDeps | _utils.Unset = _utils.UNSET, + model: models.Model | models.KnownModelName | _utils.Unset = _utils.UNSET, + ) -> Iterator[None]: + """Context manager to temporarily override agent dependencies and model. + + This is particularly useful when testing. Args: - overriding_deps: The dependencies to use instead of the dependencies passed to the agent run. + deps: The dependencies to use instead of the dependencies passed to the agent run. + model: The model to use instead of the model passed to the agent run. """ - override_deps_before = self._override_deps - self._override_deps = _utils.Some(overriding_deps) - try: - yield - finally: - self._override_deps = override_deps_before + if _utils.is_set(deps): + override_deps_before = self._override_deps + self._override_deps = _utils.Some(deps) + else: + override_deps_before = _utils.UNSET - @contextmanager - def override_model(self, overriding_model: models.Model | models.KnownModelName) -> Iterator[None]: - """Context manager to temporarily override the model used by the agent. + # noinspection PyTypeChecker + if _utils.is_set(model): + override_model_before = self._override_model + # noinspection PyTypeChecker + self._override_model = _utils.Some(models.infer_model(model)) # pyright: ignore[reportArgumentType] + else: + override_model_before = _utils.UNSET - Args: - overriding_model: The model to use instead of the model passed to the agent run. - """ - override_model_before = self._override_model - self._override_model = _utils.Some(models.infer_model(overriding_model)) try: yield finally: - self._override_model = override_model_before + if _utils.is_set(override_deps_before): + self._override_deps = override_deps_before + if _utils.is_set(override_model_before): + self._override_model = override_model_before @overload def system_prompt( @@ -575,11 +584,11 @@ async def _get_agent_model( """ model_: models.Model if some_model := self._override_model: - # we don't want `override_model()` to cover up errors from the model not being defined, hence this check + # we don't want `override()` to cover up errors from the model not being defined, hence this check if model is None and self.model is None: raise exceptions.UserError( '`model` must be set either when creating the agent or when calling it. ' - '(Even when `override_model()` is customizing the model that will actually be called)' + '(Even when `override(model=...)` is customizing the model that will actually be called)' ) model_ = some_model.value custom_model = None @@ -767,7 +776,7 @@ def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt: def _get_deps(self, deps: AgentDeps) -> AgentDeps: """Get deps for a run. - If we've overridden deps via `_override_deps_stack`, use that, otherwise use the deps passed to the call. + If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call. We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope. """ diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index f8a1d548..a3f31620 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -1,5 +1,3 @@ -"""Utilities for testing apps built with PydanticAI.""" - from __future__ import annotations as _annotations import re @@ -45,10 +43,10 @@ def __repr__(self): class TestModel(Model): """A model specifically for testing purposes. 
- This will (by default) call all tools in the agent model, then return a tool response if possible, + This will (by default) call all tools in the agent, then return a tool response if possible, otherwise a plain response. - How useful this function will be is unknown, it may be useless, it may require significant changes to be useful. + How useful this model is will vary significantly. Apart from `__init__` derived by the `dataclass` decorator, all methods are private or match those of the base class. @@ -320,8 +318,12 @@ def _str_gen(self, schema: dict[str, Any]) -> str: if schema.get('maxLength') == 0: return '' - else: - return self._char() + + if fmt := schema.get('format'): + if fmt == 'date': + return '2024-01-01' + + return self._char() def _int_gen(self, schema: dict[str, Any]) -> int: """Generate an integer from a JSON Schema integer.""" diff --git a/tests/test_agent.py b/tests/test_agent.py index 7b74a833..53b92e33 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -507,7 +507,7 @@ def test_override_model(env: TestEnv): env.set('GEMINI_API_KEY', 'foobar') agent = Agent('gemini-1.5-flash', result_type=tuple[int, str], defer_model_check=True) - with agent.override_model('test'): + with agent.override(model='test'): result = agent.run_sync('Hello') assert result.data == snapshot((0, 'a')) @@ -515,6 +515,6 @@ def test_override_model(env: TestEnv): def test_override_model_no_model(): agent = Agent() - with pytest.raises(UserError, match=r'`model` must be set either.+Even when `override_model\(\)` is customizing'): - with agent.override_model('test'): + with pytest.raises(UserError, match=r'`model` must be set either.+Even when `override\(model=...\)` is customiz'): + with agent.override(model='test'): agent.run_sync('Hello') diff --git a/tests/test_deps.py b/tests/test_deps.py index c52f3e03..7a89d73b 100644 --- a/tests/test_deps.py +++ b/tests/test_deps.py @@ -24,11 +24,11 @@ def test_deps_used(): def test_deps_override(): - with agent.override_deps(MyDeps(foo=3, bar=4)): + with agent.override(deps=MyDeps(foo=3, bar=4)): result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2)) assert result.data == '{"example_tool":"MyDeps(foo=3, bar=4)"}' - with agent.override_deps(MyDeps(foo=5, bar=6)): + with agent.override(deps=MyDeps(foo=5, bar=6)): result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2)) assert result.data == '{"example_tool":"MyDeps(foo=5, bar=6)"}' diff --git a/tests/test_examples.py b/tests/test_examples.py index 0825ce82..56851091 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -4,6 +4,7 @@ import sys from collections.abc import AsyncIterator, Iterable from dataclasses import dataclass, field +from datetime import date from types import ModuleType from typing import Any @@ -48,10 +49,17 @@ def get(self, name: str) -> int | None: @dataclass class DatabaseConn: users: FakeTable = field(default_factory=FakeTable) + _forecasts: dict[int, str] = field(default_factory=dict) async def execute(self, query: str) -> None: pass + async def store_forecast(self, user_id: int, forecast: str) -> None: + self._forecasts[user_id] = forecast + + async def get_forecast(self, user_id: int) -> str | None: + return self._forecasts.get(user_id) + class QueryError(RuntimeError): pass @@ -88,6 +96,26 @@ async def customer_balance(cls, *, id: int, include_pending: bool) -> float: sys.modules.pop(module_name) +@pytest.fixture(scope='module', autouse=True) +def weather_service(): + class WeatherService: + def get_historic_weather(self, location: str, 
forecast_date: date) -> str: + return 'Sunny with a chance of rain' + + def get_forecast(self, location: str, forecast_date: date) -> str: + return 'Rainy with a chance of sun' + + async def __aenter__(self) -> WeatherService: + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + pass + + module_name = 'weather_service' + sys.modules[module_name] = module = ModuleType(module_name) + module.__dict__.update({'WeatherService': WeatherService}) + + def find_filter_examples() -> Iterable[CodeExample]: for ex in find_examples('docs', 'pydantic_ai_slim'): if ex.path.name != '_utils.py': @@ -121,7 +149,7 @@ def test_docs_examples( ruff_ignore: list[str] = ['D'] # `from bank_database import DatabaseConn` wrongly sorted in imports # waiting for https://github.com/pydantic/pytest-examples/issues/43 - if 'from bank_database import DatabaseConn' in example.source: + if 'import DatabaseConn' in example.source: ruff_ignore.append('I001') line_length = 88 @@ -133,8 +161,10 @@ def test_docs_examples( eval_example.print_callback = print_callback call_name = 'main' - if 'def test_application_code' in example.source: - call_name = 'test_application_code' + for name in ('test_application_code', 'test_forecast_success'): + if f'def {name}' in example.source: + call_name = name + break if eval_example.update_examples: # pragma: no cover eval_example.format(example) @@ -143,6 +173,7 @@ def test_docs_examples( eval_example.lint(example) module_dict = eval_example.run_print_check(example, call=call_name) + debug(prefix_settings) if title := prefix_settings.get('title'): if title.endswith('.py'): module_name = title[:-3] diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 9295c574..e21731f3 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -123,11 +123,11 @@ async def run_stream() -> None: def run_with_override() -> None: - with typed_agent.override_deps(MyDeps(1, 2)): + with typed_agent.override(deps=MyDeps(1, 2)): typed_agent.run_sync('testing', deps=MyDeps(3, 4)) # invalid deps - with typed_agent.override_deps(123): # type: ignore[arg-type] + with typed_agent.override(deps=123): # type: ignore[arg-type] typed_agent.run_sync('testing', deps=MyDeps(3, 4)) From 16456edead0854e5602fc30ea39b09ed6f1d099b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 28 Nov 2024 17:23:47 +0000 Subject: [PATCH 2/4] more docs and change ArgsObject -> ArgsDict --- docs/agents.md | 12 +-- docs/api/models/vertexai.md | 3 - docs/testing-evals.md | 84 +++++++++++++++++-- pydantic_ai_slim/pydantic_ai/_result.py | 2 +- pydantic_ai_slim/pydantic_ai/_tool.py | 2 +- pydantic_ai_slim/pydantic_ai/messages.py | 14 ++-- .../pydantic_ai/models/function.py | 2 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 8 +- pydantic_ai_slim/pydantic_ai/models/test.py | 12 +-- tests/models/test_gemini.py | 32 +++---- tests/models/test_model_function.py | 12 +-- tests/models/test_model_test.py | 4 +- tests/test_agent.py | 12 +-- tests/test_examples.py | 48 +++++------ tests/test_logfire.py | 4 +- tests/test_streaming.py | 6 +- 16 files changed, 158 insertions(+), 99 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index f92191b5..684835d9 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -246,9 +246,7 @@ print(dice_result.all_messages()) ), ModelStructuredResponse( calls=[ - ToolCall( - tool_name='roll_die', args=ArgsObject(args_object={}), tool_id=None - ) + ToolCall(tool_name='roll_die', args=ArgsDict(args_dict={}), tool_id=None) ], timestamp=datetime.datetime(...), 
role='model-structured-response', @@ -263,9 +261,7 @@ print(dice_result.all_messages()) ModelStructuredResponse( calls=[ ToolCall( - tool_name='get_player_name', - args=ArgsObject(args_object={}), - tool_id=None, + tool_name='get_player_name', args=ArgsDict(args_dict={}), tool_id=None ) ], timestamp=datetime.datetime(...), @@ -485,7 +481,7 @@ except UnexpectedModelBehavior as e: calls=[ ToolCall( tool_name='calc_volume', - args=ArgsObject(args_object={'size': 6}), + args=ArgsDict(args_dict={'size': 6}), tool_id=None, ) ], @@ -503,7 +499,7 @@ except UnexpectedModelBehavior as e: calls=[ ToolCall( tool_name='calc_volume', - args=ArgsObject(args_object={'size': 6}), + args=ArgsDict(args_dict={'size': 6}), tool_id=None, ) ], diff --git a/docs/api/models/vertexai.md b/docs/api/models/vertexai.md index 6e332ecc..cd97403d 100644 --- a/docs/api/models/vertexai.md +++ b/docs/api/models/vertexai.md @@ -10,9 +10,6 @@ and function endpoints having the same schemas as the equivalent [Gemini endpoints][pydantic_ai.models.gemini.GeminiModel]. -There are four advantages of using this API over the `generativelanguage.googleapis.com` API which -[`GeminiModel`][pydantic_ai.models.gemini.GeminiModel] uses, and one big disadvantage. - ## Setup For details on how to set up authentication with this model as well as a comparison with the `generativelanguage.googleapis.com` API used by [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel], diff --git a/docs/testing-evals.md b/docs/testing-evals.md index 2c760b50..36e80165 100644 --- a/docs/testing-evals.md +++ b/docs/testing-evals.md @@ -1,5 +1,3 @@ -from black import timezonefrom black import timezonefrom black import timezone - # Testing and Evals When thinking about PydanticAI use and LLM integrations in general, there are two distinct kinds of test: @@ -109,7 +107,7 @@ from pydantic_ai.messages import ( UserPrompt, ModelStructuredResponse, ToolCall, - ArgsObject, + ArgsDict, ToolReturn, ModelTextResponse, ) @@ -121,7 +119,7 @@ pytestmark = pytest.mark.anyio # (1)! models.ALLOW_MODEL_REQUESTS = False # (2)! -async def test_forecast_success(): +async def test_forecast(): conn = DatabaseConn() user_id = 1 with weather_agent.override(model=TestModel()): # (3)! @@ -145,8 +143,8 @@ async def test_forecast_success(): calls=[ ToolCall( tool_name='weather_forecast', - args=ArgsObject( - args_object={'location': 'a', 'forecast_date': '2024-01-01'} + args=ArgsDict( + args_dict={'location': 'a', 'forecast_date': '2024-01-01'} ), tool_id=None, ) @@ -179,10 +177,78 @@ async def test_forecast_success(): ### Unit testing with `FunctionModel` -TODO +The above tests are a great start, but careful readers will notice that the `WeatherService.get_forecast` is never called since `TestModel` calls `weather_forecast` with a date in the past. + +To fully exercise `weather_forecast`, we need to use [`FunctionModel`][pydantic_ai.models.function.FunctionModel] to customise how the tools is called. 
+ +Here's an example of using `FunctionModel` to test the `weather_forecast` tool with custom inputs + +```py title="test_weather_app2.py" +import re + +import pytest + +from pydantic_ai import models +from pydantic_ai.messages import ( + Message, + ModelAnyResponse, + ModelStructuredResponse, + ModelTextResponse, + ToolCall, +) +from pydantic_ai.models.function import AgentInfo, FunctionModel + +from fake_database import DatabaseConn +from weather_app import run_weather_forecast, weather_agent + +pytestmark = pytest.mark.anyio +models.ALLOW_MODEL_REQUESTS = False + + +async def test_forecast_future(): + def call_weather_forecast( # (1)! + messages: list[Message], info: AgentInfo + ) -> ModelAnyResponse: + if len(messages) == 2: + # first call, call the weather forecast tool + assert set(info.function_tools.keys()) == {'weather_forecast'} + + user_prompt = messages[1] + m = re.search(r'\d{4}-\d{2}-\d{2}', user_prompt.content) + assert m is not None + args = {'location': 'London', 'forecast_date': m.group()} + return ModelStructuredResponse( + calls=[ToolCall.from_dict('weather_forecast', args)] + ) + else: + # second call, return the forecast + msg = messages[-1] + assert msg.role == 'tool-return' + return ModelTextResponse(f'The forecast is: {msg.content}') + + conn = DatabaseConn() + user_id = 1 + with weather_agent.override(model=FunctionModel(call_weather_forecast)): # (2)! + prompt = 'What will the weather be like in London on 2032-01-01?' + await run_weather_forecast([(prompt, user_id)], conn) + + forecast = await conn.get_forecast(user_id) + assert forecast == 'The forecast is: Rainy with a chance of sun' +``` + +1. We define a function `call_weather_forecast` that will be called by `FunctionModel` in place of the LLM, this function has access to the list of [`Message`s][pydantic_ai.messages.Message] that make up the run, and [`AgentInfo`][pydantic_ai.models.function.AgentInfo] which contains information about the agent and the function tools and return type tools. +2. We use [`FunctionModel`][pydantic_ai.models.function.FunctionModel] to replace the agent's model with our custom function. ## Evals -TODO. +"Evals" refers to evaluating the performance of an LLM when used in a specific context. + +Unlike unit tests, evals are an emerging art/science, anyone who tells you they know exactly how evals should be defined can safely be ignored. + +Evals are generally more like benchmarks than unit tests, they never "pass" although they do "fail"; you care mostly about how they change over time. + +### System prompt customization + +The system prompt is the developer's primary tool in controlling the LLM's behavior, so it's often useful to be able to customise the system prompt and see how performance changes. -evals are more like benchmarks, they never "pass" although they do "fail", you care mostly about how they change over time, we (and we think most other people) don't really know what a "good" eval is, we provide some useful tools, we'll improve this if/when a common best practice emerges, or we think we have something interesting to say +TODO example of customizing system prompt through deps. 
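+A rough sketch of what that can look like (the `PromptConfig` dataclass, its `tone` field and the `collect_responses` helper here are illustrative placeholders, not PydanticAI APIs): build the system prompt from deps, and an eval harness can then swap prompt variants without touching the agent itself.
+
+```py title="prompt_variants_sketch.py"
+from dataclasses import dataclass
+
+from pydantic_ai import Agent, CallContext
+
+
+@dataclass
+class PromptConfig:
+    # hypothetical deps carrying the system prompt variant under evaluation
+    tone: str
+
+
+eval_agent = Agent('openai:gpt-4o', deps_type=PromptConfig)
+
+
+@eval_agent.system_prompt
+async def system_prompt(ctx: CallContext[PromptConfig]) -> str:
+    # the prompt text comes from deps, so each eval run can use a different variant
+    return f'Answer the user in a {ctx.deps.tone} tone.'
+
+
+async def collect_responses(tone: str, prompts: list[str]) -> list[str]:
+    """Gather responses for one prompt variant so they can be scored offline."""
+    responses: list[str] = []
+    for prompt in prompts:
+        result = await eval_agent.run(prompt, deps=PromptConfig(tone=tone))
+        responses.append(result.data)
+    return responses
+```
+
+Because the prompt is driven by deps, the same approach works against existing application code via [`Agent.override`][pydantic_ai.agent.Agent.override] with `deps=...`, and the resulting scores for each variant can be tracked over time rather than asserted against a fixed pass/fail threshold.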
diff --git a/pydantic_ai_slim/pydantic_ai/_result.py b/pydantic_ai_slim/pydantic_ai/_result.py index 26d7ace4..06a61826 100644 --- a/pydantic_ai_slim/pydantic_ai/_result.py +++ b/pydantic_ai_slim/pydantic_ai/_result.py @@ -191,7 +191,7 @@ def validate( ) else: result = self.type_adapter.validate_python( - tool_call.args.args_object, experimental_allow_partial=pyd_allow_partial + tool_call.args.args_dict, experimental_allow_partial=pyd_allow_partial ) except ValidationError as e: if wrap_validation_errors: diff --git a/pydantic_ai_slim/pydantic_ai/_tool.py b/pydantic_ai_slim/pydantic_ai/_tool.py index b808c305..6ebcad44 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool.py +++ b/pydantic_ai_slim/pydantic_ai/_tool.py @@ -59,7 +59,7 @@ async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Mes if isinstance(message.args, messages.ArgsJson): args_dict = self.validator.validate_json(message.args.args_json) else: - args_dict = self.validator.validate_python(message.args.args_object) + args_dict = self.validator.validate_python(message.args.args_dict) except ValidationError as e: return self._on_error(e, message) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 9331d25c..1c881cd9 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -144,8 +144,8 @@ class ArgsJson: @dataclass -class ArgsObject: - args_object: dict[str, Any] +class ArgsDict: + args_dict: dict[str, Any] """A python dictionary of arguments.""" @@ -155,7 +155,7 @@ class ToolCall: tool_name: str """The name of the tool to call.""" - args: ArgsJson | ArgsObject + args: ArgsJson | ArgsDict """The arguments to pass to the tool. Either as JSON or a Python dictionary depending on how data was returned. @@ -168,12 +168,12 @@ def from_json(cls, tool_name: str, args_json: str, tool_id: str | None = None) - return cls(tool_name, ArgsJson(args_json), tool_id) @classmethod - def from_object(cls, tool_name: str, args_object: dict[str, Any]) -> ToolCall: - return cls(tool_name, ArgsObject(args_object)) + def from_dict(cls, tool_name: str, args_dict: dict[str, Any]) -> ToolCall: + return cls(tool_name, ArgsDict(args_dict)) def has_content(self) -> bool: - if isinstance(self.args, ArgsObject): - return any(self.args.args_object.values()) + if isinstance(self.args, ArgsDict): + return any(self.args.args_dict.values()) else: return bool(self.args.args_json) diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 135aa3cf..9ca06d00 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -250,7 +250,7 @@ def _estimate_cost(messages: Iterable[Message]) -> result.Cost: if isinstance(call.args, ArgsJson): args_str = call.args.args_json else: - args_str = pydantic_core.to_json(call.args.args_object).decode() + args_str = pydantic_core.to_json(call.args.args_dict).decode() response_tokens += 1 + _string_cost(args_str) else: diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 08241a7e..07fc7721 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -16,7 +16,7 @@ from .. 
import UnexpectedModelBehavior, _pydantic, _utils, exceptions, result from ..messages import ( - ArgsObject, + ArgsDict, Message, ModelAnyResponse, ModelStructuredResponse, @@ -420,15 +420,15 @@ class _GeminiFunctionCallPart(TypedDict): def _function_call_part_from_call(tool: ToolCall) -> _GeminiFunctionCallPart: - assert isinstance(tool.args, ArgsObject), f'Expected ArgsObject, got {tool.args}' - return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args.args_object)) + assert isinstance(tool.args, ArgsDict), f'Expected ArgsObject, got {tool.args}' + return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args.args_dict)) def _structured_response_from_parts( parts: list[_GeminiFunctionCallPart], timestamp: datetime | None = None ) -> ModelStructuredResponse: return ModelStructuredResponse( - calls=[ToolCall.from_object(part['function_call']['name'], part['function_call']['args']) for part in parts], + calls=[ToolCall.from_dict(part['function_call']['name'], part['function_call']['args']) for part in parts], timestamp=timestamp or _utils.now_utc(), ) diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index a3f31620..bdcd5c00 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterator, Iterable, Iterator, Mapping, Sequence from contextlib import asynccontextmanager from dataclasses import dataclass, field -from datetime import datetime +from datetime import date, datetime, timedelta from typing import Any, Literal import pydantic_core @@ -140,7 +140,7 @@ def gen_tool_args(self, tool_def: AbstractToolDefinition) -> Any: def _request(self, messages: list[Message]) -> ModelAnyResponse: if self.step == 0 and self.tool_calls: - calls = [ToolCall.from_object(name, self.gen_tool_args(args)) for name, args in self.tool_calls] + calls = [ToolCall.from_dict(name, self.gen_tool_args(args)) for name, args in self.tool_calls] self.step += 1 self.last_message_count = len(messages) return ModelStructuredResponse(calls=calls) @@ -150,7 +150,7 @@ def _request(self, messages: list[Message]) -> ModelAnyResponse: new_retry_names = {m.tool_name for m in new_messages if isinstance(m, RetryPrompt)} if new_retry_names: calls = [ - ToolCall.from_object(name, self.gen_tool_args(args)) + ToolCall.from_dict(name, self.gen_tool_args(args)) for name, args in self.tool_calls if name in new_retry_names ] @@ -177,11 +177,11 @@ def _request(self, messages: list[Message]) -> ModelAnyResponse: result_tool = self.result_tools[self.seed % len(self.result_tools)] if custom_result_args is not None: self.step += 1 - return ModelStructuredResponse(calls=[ToolCall.from_object(result_tool.name, custom_result_args)]) + return ModelStructuredResponse(calls=[ToolCall.from_dict(result_tool.name, custom_result_args)]) else: response_args = self.gen_tool_args(result_tool) self.step += 1 - return ModelStructuredResponse(calls=[ToolCall.from_object(result_tool.name, response_args)]) + return ModelStructuredResponse(calls=[ToolCall.from_dict(result_tool.name, response_args)]) @dataclass @@ -321,7 +321,7 @@ def _str_gen(self, schema: dict[str, Any]) -> str: if fmt := schema.get('format'): if fmt == 'date': - return '2024-01-01' + return (date(2024, 1, 1) + timedelta(days=self.seed)).isoformat() return self._char() diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 476b4ec7..fb24e975 
100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -14,7 +14,7 @@ from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior, UserError, _utils from pydantic_ai.messages import ( - ArgsObject, + ArgsDict, ModelStructuredResponse, ModelTextResponse, RetryPrompt, @@ -422,7 +422,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient): async def test_request_structured_response(get_gemini_client: GetGeminiClient): response = gemini_response( _content_function_call( - ModelStructuredResponse(calls=[ToolCall.from_object('final_result', {'response': [1, 2, 123]})]) + ModelStructuredResponse(calls=[ToolCall.from_dict('final_result', {'response': [1, 2, 123]})]) ) ) gemini_client = get_gemini_client(response) @@ -438,7 +438,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient): calls=[ ToolCall( tool_name='final_result', - args=ArgsObject(args_object={'response': [1, 2, 123]}), + args=ArgsDict(args_dict={'response': [1, 2, 123]}), ) ], timestamp=IsNow(tz=timezone.utc), @@ -451,12 +451,12 @@ async def test_request_tool_call(get_gemini_client: GetGeminiClient): responses = [ gemini_response( _content_function_call( - ModelStructuredResponse(calls=[ToolCall.from_object('get_location', {'loc_name': 'San Fransisco'})]) + ModelStructuredResponse(calls=[ToolCall.from_dict('get_location', {'loc_name': 'San Fransisco'})]) ) ), gemini_response( _content_function_call( - ModelStructuredResponse(calls=[ToolCall.from_object('get_location', {'loc_name': 'London'})]) + ModelStructuredResponse(calls=[ToolCall.from_dict('get_location', {'loc_name': 'London'})]) ) ), gemini_response(_content_model_text('final response')), @@ -482,7 +482,7 @@ async def get_location(loc_name: str) -> str: calls=[ ToolCall( tool_name='get_location', - args=ArgsObject(args_object={'loc_name': 'San Fransisco'}), + args=ArgsDict(args_dict={'loc_name': 'San Fransisco'}), ) ], timestamp=IsNow(tz=timezone.utc), @@ -494,7 +494,7 @@ async def get_location(loc_name: str) -> str: calls=[ ToolCall( tool_name='get_location', - args=ArgsObject(args_object={'loc_name': 'London'}), + args=ArgsDict(args_dict={'loc_name': 'London'}), ) ], timestamp=IsNow(tz=timezone.utc), @@ -531,7 +531,7 @@ async def test_heterogeneous_responses(get_gemini_client: GetGeminiClient): _function_call_part_from_call( ToolCall( tool_name='get_location', - args=ArgsObject(args_object={'loc_name': 'San Fransisco'}), + args=ArgsDict(args_dict={'loc_name': 'San Fransisco'}), ) ), ], @@ -587,7 +587,7 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient): responses = [ gemini_response( _content_function_call( - ModelStructuredResponse(calls=[ToolCall.from_object('final_result', {'response': [1, 2]})]) + ModelStructuredResponse(calls=[ToolCall.from_dict('final_result', {'response': [1, 2]})]) ), ), ] @@ -606,10 +606,10 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient): async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient): first_responses = [ gemini_response( - _content_function_call(ModelStructuredResponse(calls=[ToolCall.from_object('foo', {'x': 'a'})])), + _content_function_call(ModelStructuredResponse(calls=[ToolCall.from_dict('foo', {'x': 'a'})])), ), gemini_response( - _content_function_call(ModelStructuredResponse(calls=[ToolCall.from_object('bar', {'y': 'b'})])), + _content_function_call(ModelStructuredResponse(calls=[ToolCall.from_dict('bar', {'y': 'b'})])), ), ] d1 = _gemini_streamed_response_ta.dump_json(first_responses, 
by_alias=True) @@ -618,7 +618,7 @@ async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient): second_responses = [ gemini_response( _content_function_call( - ModelStructuredResponse(calls=[ToolCall.from_object('final_result', {'response': [1, 2]})]) + ModelStructuredResponse(calls=[ToolCall.from_dict('final_result', {'response': [1, 2]})]) ), ), ] @@ -649,8 +649,8 @@ async def bar(y: str) -> str: UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), ModelStructuredResponse( calls=[ - ToolCall(tool_name='foo', args=ArgsObject(args_object={'x': 'a'})), - ToolCall(tool_name='bar', args=ArgsObject(args_object={'y': 'b'})), + ToolCall(tool_name='foo', args=ArgsDict(args_dict={'x': 'a'})), + ToolCall(tool_name='bar', args=ArgsDict(args_dict={'y': 'b'})), ], timestamp=IsNow(tz=timezone.utc), ), @@ -660,7 +660,7 @@ async def bar(y: str) -> str: calls=[ ToolCall( tool_name='final_result', - args=ArgsObject(args_object={'response': [1, 2]}), + args=ArgsDict(args_dict={'response': [1, 2]}), ) ], timestamp=IsNow(tz=timezone.utc), @@ -681,7 +681,7 @@ async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient): _function_call_part_from_call( ToolCall( tool_name='get_location', - args=ArgsObject(args_object={'loc_name': 'San Fransisco'}), + args=ArgsDict(args_dict={'loc_name': 'San Fransisco'}), ) ), ], diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index 44e01aa2..c82fb978 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -327,11 +327,11 @@ def test_call_all(): UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), ModelStructuredResponse( calls=[ - ToolCall.from_object('foo', {'x': 0}), - ToolCall.from_object('bar', {'x': 0}), - ToolCall.from_object('baz', {'x': 0}), - ToolCall.from_object('qux', {'x': 0}), - ToolCall.from_object('quz', {'x': 'a'}), + ToolCall.from_dict('foo', {'x': 0}), + ToolCall.from_dict('bar', {'x': 0}), + ToolCall.from_dict('baz', {'x': 0}), + ToolCall.from_dict('qux', {'x': 0}), + ToolCall.from_dict('quz', {'x': 'a'}), ], timestamp=IsNow(tz=timezone.utc), ), @@ -376,7 +376,7 @@ async def try_again(messages: list[Message], _: AgentInfo) -> ModelAnyResponse: nonlocal call_count call_count += 1 - return ModelStructuredResponse(calls=[ToolCall.from_object('final_result', {'x': call_count})]) + return ModelStructuredResponse(calls=[ToolCall.from_dict('final_result', {'x': call_count})]) class Foo(BaseModel): x: int diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index fac96568..ae47d223 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -94,11 +94,11 @@ async def my_ret(x: int) -> str: [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), ModelStructuredResponse( - calls=[ToolCall.from_object('my_ret', {'x': 0})], + calls=[ToolCall.from_dict('my_ret', {'x': 0})], timestamp=IsNow(tz=timezone.utc), ), RetryPrompt(tool_name='my_ret', content='First call failed', timestamp=IsNow(tz=timezone.utc)), - ModelStructuredResponse(calls=[ToolCall.from_object('my_ret', {'x': 0})], timestamp=IsNow(tz=timezone.utc)), + ModelStructuredResponse(calls=[ToolCall.from_dict('my_ret', {'x': 0})], timestamp=IsNow(tz=timezone.utc)), ToolReturn(tool_name='my_ret', content='1', timestamp=IsNow(tz=timezone.utc)), ModelTextResponse(content='{"my_ret":"1"}', timestamp=IsNow(tz=timezone.utc)), ] diff --git a/tests/test_agent.py b/tests/test_agent.py index 53b92e33..c10457dc 100644 --- 
a/tests/test_agent.py +++ b/tests/test_agent.py @@ -8,8 +8,8 @@ from pydantic_ai import Agent, CallContext, ModelRetry, UnexpectedModelBehavior, UserError from pydantic_ai.messages import ( + ArgsDict, ArgsJson, - ArgsObject, Message, ModelAnyResponse, ModelStructuredResponse, @@ -370,7 +370,7 @@ async def ret_a(x: str) -> str: [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), ModelStructuredResponse( - calls=[ToolCall(tool_name='ret_a', args=ArgsObject(args_object={'x': 'a'}))], + calls=[ToolCall(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))], timestamp=IsNow(tz=timezone.utc), ), ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), @@ -387,7 +387,7 @@ async def ret_a(x: str) -> str: SystemPrompt(content='Foobar'), UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), ModelStructuredResponse( - calls=[ToolCall(tool_name='ret_a', args=ArgsObject(args_object={'x': 'a'}))], + calls=[ToolCall(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))], timestamp=IsNow(tz=timezone.utc), ), ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), @@ -395,7 +395,7 @@ async def ret_a(x: str) -> str: # second call, notice no repeated system prompt UserPrompt(content='Hello again', timestamp=IsNow(tz=timezone.utc)), ModelStructuredResponse( - calls=[ToolCall(tool_name='ret_a', args=ArgsObject(args_object={'x': 'a'}))], + calls=[ToolCall(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))], timestamp=IsNow(tz=timezone.utc), ), ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), @@ -420,7 +420,7 @@ async def ret_a(x: str) -> str: SystemPrompt(content='Foobar'), UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), ModelStructuredResponse( - calls=[ToolCall(tool_name='ret_a', args=ArgsObject(args_object={'x': 'a'}))], + calls=[ToolCall(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))], timestamp=IsNow(tz=timezone.utc), ), ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), @@ -428,7 +428,7 @@ async def ret_a(x: str) -> str: # second call, notice no repeated system prompt UserPrompt(content='Hello again', timestamp=IsNow(tz=timezone.utc)), ModelStructuredResponse( - calls=[ToolCall(tool_name='ret_a', args=ArgsObject(args_object={'x': 'a'}))], + calls=[ToolCall(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))], timestamp=IsNow(tz=timezone.utc), ), ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), diff --git a/tests/test_examples.py b/tests/test_examples.py index 56851091..c9a46b24 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -17,7 +17,7 @@ from pydantic_ai._utils import group_by_temporal from pydantic_ai.messages import ( - ArgsObject, + ArgsDict, Message, ModelAnyResponse, ModelStructuredResponse, @@ -161,8 +161,8 @@ def test_docs_examples( eval_example.print_callback = print_callback call_name = 'main' - for name in ('test_application_code', 'test_forecast_success'): - if f'def {name}' in example.source: + for name in ('test_application_code', 'test_forecast', 'test_forecast_future'): + if f'def {name}():' in example.source: call_name = name break @@ -208,20 +208,20 @@ async def async_http_request(url: str, **kwargs: Any) -> httpx.Response: 'Who was Albert Einstein?': 'Albert Einstein was a German-born theoretical physicist.', 'What was his most famous equation?': "Albert Einstein's most famous equation is (E = mc^2).", 'What is the date?': 'Hello Frank, the date today 
is 2032-01-02.', - 'Put my money on square eighteen': ToolCall(tool_name='roulette_wheel', args=ArgsObject({'square': 18})), - 'I bet five is the winner': ToolCall(tool_name='roulette_wheel', args=ArgsObject({'square': 5})), - 'My guess is 4': ToolCall(tool_name='roll_die', args=ArgsObject({})), + 'Put my money on square eighteen': ToolCall(tool_name='roulette_wheel', args=ArgsDict({'square': 18})), + 'I bet five is the winner': ToolCall(tool_name='roulette_wheel', args=ArgsDict({'square': 5})), + 'My guess is 4': ToolCall(tool_name='roll_die', args=ArgsDict({})), 'Send a message to John Doe asking for coffee next week': ToolCall( - tool_name='get_user_by_name', args=ArgsObject({'name': 'John'}) + tool_name='get_user_by_name', args=ArgsDict({'name': 'John'}) ), - 'Please get me the volume of a box with size 6.': ToolCall(tool_name='calc_volume', args=ArgsObject({'size': 6})), + 'Please get me the volume of a box with size 6.': ToolCall(tool_name='calc_volume', args=ArgsDict({'size': 6})), 'Where does "hello world" come from?': ( 'The first known use of "hello, world" was in a 1974 textbook about the C programming language.' ), - 'What is my balance?': ToolCall(tool_name='customer_balance', args=ArgsObject({'include_pending': True})), + 'What is my balance?': ToolCall(tool_name='customer_balance', args=ArgsDict({'include_pending': True})), 'I just lost my card!': ToolCall( tool_name='final_result', - args=ArgsObject( + args=ArgsDict( { 'support_advice': ( "I'm sorry to hear that, John. " @@ -234,28 +234,28 @@ async def async_http_request(url: str, **kwargs: Any) -> httpx.Response: ), 'Where the olympics held in 2012?': ToolCall( tool_name='final_result', - args=ArgsObject({'city': 'London', 'country': 'United Kingdom'}), + args=ArgsDict({'city': 'London', 'country': 'United Kingdom'}), ), 'The box is 10x20x30': 'Please provide the units for the dimensions (e.g., cm, in, m).', 'The box is 10x20x30 cm': ToolCall( tool_name='final_result', - args=ArgsObject({'width': 10, 'height': 20, 'depth': 30, 'units': 'cm'}), + args=ArgsDict({'width': 10, 'height': 20, 'depth': 30, 'units': 'cm'}), ), 'red square, blue circle, green triangle': ToolCall( tool_name='final_result_list', - args=ArgsObject({'response': ['red', 'blue', 'green']}), + args=ArgsDict({'response': ['red', 'blue', 'green']}), ), 'square size 10, circle size 20, triangle size 30': ToolCall( tool_name='final_result_list_2', - args=ArgsObject({'response': [10, 20, 30]}), + args=ArgsDict({'response': [10, 20, 30]}), ), 'get me uses who were last active yesterday.': ToolCall( tool_name='final_result_Success', - args=ArgsObject({'sql_query': 'SELECT * FROM users WHERE last_active::date = today() - interval 1 day'}), + args=ArgsDict({'sql_query': 'SELECT * FROM users WHERE last_active::date = today() - interval 1 day'}), ), 'My name is Ben, I was born on January 28th 1990, I like the chain the dog and the pyramid.': ToolCall( tool_name='final_result', - args=ArgsObject( + args=ArgsDict( { 'name': 'Ben', 'dob': '1990-01-28', @@ -277,30 +277,30 @@ async def model_logic(messages: list[Message], info: AgentInfo) -> ModelAnyRespo elif m.role == 'tool-return' and m.tool_name == 'roulette_wheel': win = m.content == 'winner' - return ModelStructuredResponse(calls=[ToolCall(tool_name='final_result', args=ArgsObject({'response': win}))]) + return ModelStructuredResponse(calls=[ToolCall(tool_name='final_result', args=ArgsDict({'response': win}))]) elif m.role == 'tool-return' and m.tool_name == 'roll_die': - return 
ModelStructuredResponse(calls=[ToolCall(tool_name='get_player_name', args=ArgsObject({}))]) + return ModelStructuredResponse(calls=[ToolCall(tool_name='get_player_name', args=ArgsDict({}))]) elif m.role == 'tool-return' and m.tool_name == 'get_player_name': return ModelTextResponse(content="Congratulations Adam, you guessed correctly! You're a winner!") if m.role == 'retry-prompt' and isinstance(m.content, str) and m.content.startswith("No user found with name 'Joh"): return ModelStructuredResponse( - calls=[ToolCall(tool_name='get_user_by_name', args=ArgsObject({'name': 'John Doe'}))] + calls=[ToolCall(tool_name='get_user_by_name', args=ArgsDict({'name': 'John Doe'}))] ) elif m.role == 'tool-return' and m.tool_name == 'get_user_by_name': args = { 'message': 'Hello John, would you be free for coffee sometime next week? Let me know what works for you!', 'user_id': 123, } - return ModelStructuredResponse(calls=[ToolCall(tool_name='final_result', args=ArgsObject(args))]) + return ModelStructuredResponse(calls=[ToolCall(tool_name='final_result', args=ArgsDict(args))]) elif m.role == 'retry-prompt' and m.tool_name == 'calc_volume': - return ModelStructuredResponse(calls=[ToolCall(tool_name='calc_volume', args=ArgsObject({'size': 6}))]) + return ModelStructuredResponse(calls=[ToolCall(tool_name='calc_volume', args=ArgsDict({'size': 6}))]) elif m.role == 'tool-return' and m.tool_name == 'customer_balance': args = { 'support_advice': 'Hello John, your current account balance, including pending transactions, is $123.45.', 'block_card': False, 'risk': 1, } - return ModelStructuredResponse(calls=[ToolCall(tool_name='final_result', args=ArgsObject(args))]) + return ModelStructuredResponse(calls=[ToolCall(tool_name='final_result', args=ArgsDict(args))]) else: sys.stdout.write(str(debug.format(messages, info))) raise RuntimeError(f'Unexpected message: {m}') @@ -324,8 +324,8 @@ async def stream_model_logic( yield ' '.join(chunk) return else: - if isinstance(response.args, ArgsObject): - json_text = pydantic_core.to_json(response.args.args_object).decode() + if isinstance(response.args, ArgsDict): + json_text = pydantic_core.to_json(response.args.args_dict).decode() else: json_text = response.args.args_json diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 0141fca6..1102ee48 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -98,7 +98,7 @@ async def my_ret(x: int) -> str: 'logfire.span_type': 'span', 'response': IsJson( { - 'calls': [{'tool_name': 'my_ret', 'args': {'args_object': {'x': 0}}, 'tool_id': None}], + 'calls': [{'tool_name': 'my_ret', 'args': {'args_dict': {'x': 0}}, 'tool_id': None}], 'timestamp': IsStr() & IsNow(iso_string=True, tz=timezone.utc), 'role': 'model-structured-response', } @@ -124,7 +124,7 @@ async def my_ret(x: int) -> str: 'properties': { 'args': { 'type': 'object', - 'title': 'ArgsObject', + 'title': 'ArgsDict', 'x-python-datatype': 'dataclass', } }, diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 4b272a23..d19f28b7 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -9,8 +9,8 @@ from pydantic_ai import Agent, UnexpectedModelBehavior, UserError from pydantic_ai.messages import ( + ArgsDict, ArgsJson, - ArgsObject, Message, ModelStructuredResponse, ModelTextResponse, @@ -44,7 +44,7 @@ async def ret_a(x: str) -> str: [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), ModelStructuredResponse( - calls=[ToolCall(tool_name='ret_a', args=ArgsObject(args_object={'x': 'a'}))], + 
calls=[ToolCall(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))], timestamp=IsNow(tz=timezone.utc), ), ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), @@ -59,7 +59,7 @@ async def ret_a(x: str) -> str: [ UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)), ModelStructuredResponse( - calls=[ToolCall(tool_name='ret_a', args=ArgsObject(args_object={'x': 'a'}))], + calls=[ToolCall(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))], timestamp=IsNow(tz=timezone.utc), ), ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)), From 1da316a1f72744e063699b58f21e415c3c997f4d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 28 Nov 2024 17:27:58 +0000 Subject: [PATCH 3/4] fix for python 3.9 --- pydantic_ai_slim/pydantic_ai/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index daa59874..b8879609 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -8,11 +8,11 @@ from datetime import datetime, timezone from functools import partial from types import GenericAlias -from typing import Any, Callable, Generic, TypeGuard, TypeVar, Union, cast, overload +from typing import Any, Callable, Generic, TypeVar, Union, cast, overload from pydantic import BaseModel from pydantic.json_schema import JsonSchemaValue -from typing_extensions import ParamSpec, TypeAlias, is_typeddict +from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict _P = ParamSpec('_P') _R = TypeVar('_R') From 48022bd4485b5cb27fe93f43a83032699266f7de Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 29 Nov 2024 10:47:39 +0000 Subject: [PATCH 4/4] tweaks and cleanup --- docs/testing-evals.md | 111 +++++++++++++++++++------------- pydantic_ai_examples/sql_gen.py | 6 +- 2 files changed, 72 insertions(+), 45 deletions(-) diff --git a/docs/testing-evals.md b/docs/testing-evals.md index 36e80165..35ecee76 100644 --- a/docs/testing-evals.md +++ b/docs/testing-evals.md @@ -1,6 +1,6 @@ # Testing and Evals -When thinking about PydanticAI use and LLM integrations in general, there are two distinct kinds of test: +With PydanticAI and LLM integrations in general, there are two distinct kinds of test: 1. **Unit tests** — tests of your application code, and whether it's behaving correctly 2. **"Evals"** — tests of the LLM, and how good or bad its responses are @@ -16,11 +16,11 @@ Because for the most part they're nothing new, we have pretty well established t Unless you're really sure you know better, you'll probably want to follow roughly this strategy: * Use [`pytest`](https://docs.pytest.org/en/stable/) as your test harness -* If you find yourself typing out long assertions, use [`inline-snapshot`](https://15r10nk.github.io/inline-snapshot/latest/) +* If you find yourself typing out long assertions, use [inline-snapshot](https://15r10nk.github.io/inline-snapshot/latest/) * Similarly, [dirty-equals](https://dirty-equals.helpmanual.io/latest/) can be useful for comparing large data structures * Use [`TestModel`][pydantic_ai.models.test.TestModel] or [`FunctionModel`][pydantic_ai.models.function.FunctionModel] in place of your actual model to avoid the cost, latency and variability of real LLM calls -* Use [`Agent.override`][pydantic_ai.agent.Agent.override] to replace your model inside your application logic. 
-* Set [`ALLOW_MODEL_REQUESTS=False`][pydantic_ai.models.ALLOW_MODEL_REQUESTS] globally to block any requests from being made to non-test models +* Use [`Agent.override`][pydantic_ai.agent.Agent.override] to replace your model inside your application logic +* Set [`ALLOW_MODEL_REQUESTS=False`][pydantic_ai.models.ALLOW_MODEL_REQUESTS] globally to block any requests from being made to non-test models accidentally ### Unit testing with `TestModel` @@ -34,7 +34,7 @@ The simplest and fastest way to exercise most of your application code is using The resulting data won't look pretty or relevant, but it should pass Pydantic's validation in most cases. If you want something more sophisticated, use [`FunctionModel`][pydantic_ai.models.function.FunctionModel] and write your own data generation logic. -Let's consider the following application code: +Let's write unit tests for the following application code: ```py title="weather_app.py" import asyncio @@ -56,8 +56,7 @@ weather_agent = Agent( def weather_forecast( ctx: CallContext[WeatherService], location: str, forecast_date: date ) -> str: - if forecast_date < date.today(): - # (pretend) use the cheaper endpoint to get historical data + if forecast_date < date.today(): # (3)! return ctx.deps.get_historic_weather(location, forecast_date) else: return ctx.deps.get_forecast(location, forecast_date) @@ -66,12 +65,7 @@ def weather_forecast( async def run_weather_forecast( # (3)! user_prompts: list[tuple[str, int]], conn: DatabaseConn ): - """Run weather forecast for a list of user prompts. - - Args: - user_prompts: A list of tuples containing the user prompt and user id. - conn: A database connection to store the forecast results. - """ + """Run weather forecast for a list of user prompts and save.""" async with WeatherService() as weather_service: async def run_forecast(prompt: str, user_id: int): @@ -85,12 +79,13 @@ async def run_weather_forecast( # (3)! ``` 1. `DatabaseConn` is a class that holds a database connection -2. `WeatherService` is a class that provides weather data -3. This function is the code we want to test, together with the agent it uses +2. `WeatherService` has methods to get weather forecasts and historic data about the weather +3. We need to call a different endpoint depending on whether the date is in the past or the future, you'll see why this nuance is important below +4. This function is the code we want to test, together with the agent it uses Here we have a function that takes a list of `#!python (user_prompt, user_id)` tuples, gets a weather forecast for each prompt, and stores the result in the database. -We want to test this code without having to mock certain objects or modify our code so we can pass test objects in. +**We want to test this code without having to mock certain objects or modify our code so we can pass test objects in.** Here's how we would write tests using [`TestModel`][pydantic_ai.models.test.TestModel]: @@ -144,7 +139,10 @@ async def test_forecast(): ToolCall( tool_name='weather_forecast', args=ArgsDict( - args_dict={'location': 'a', 'forecast_date': '2024-01-01'} + args_dict={ + 'location': 'a', + 'forecast_date': '2024-01-01', # (8)! + } ), tool_id=None, ) @@ -168,12 +166,13 @@ async def test_forecast(): ``` 1. We're using [anyio](https://anyio.readthedocs.io/en/stable/) to run async tests. -2. This is a safety measure to make sure we don't accidentally make real requests to the LLM while testing. -3. 
We're using [`override`][pydantic_ai.agent.Agent.override] to replace the agent's model with [`TestModel`][pydantic_ai.models.test.TestModel]. -4. Now we call the function we want to test. +2. This is a safety measure to make sure we don't accidentally make real requests to the LLM while testing, see [`ALLOW_MODEL_REQUESTS`][pydantic_ai.models.ALLOW_MODEL_REQUESTS] for more details. +3. We're using [`Agent.override`][pydantic_ai.agent.Agent.override] to replace the agent's model with [`TestModel`][pydantic_ai.models.test.TestModel], the nice thing about `override` is that we can replace the model inside agent without needing access to the agent `run*` methods call site. +4. Now we call the function we want to test inside the `override` context manager. 5. But default, `TestModel` will return a JSON string summarising the tools calls made, and what was returned. If you wanted to customise the response to something more closely aligned with the domain, you could add [`custom_result_text='Sunny'`][pydantic_ai.models.test.TestModel.custom_result_text] when defining `TestModel`. 6. So far we don't actually know which tools were called and with which values, we can use the [`last_run_messages`][pydantic_ai.agent.Agent.last_run_messages] attribute to inspect messages from the most recent run and assert the exchange between the agent and the model occurred as expected. 7. The [`IsNow`][dirty_equals.IsNow] helper allows us to use declarative asserts even with data which will contain timestamps that change over time. +8. `TestModel` isn't doing anything clever to extract values from the prompt, so these values are hardcoded. ### Unit testing with `FunctionModel` @@ -205,30 +204,31 @@ pytestmark = pytest.mark.anyio models.ALLOW_MODEL_REQUESTS = False -async def test_forecast_future(): - def call_weather_forecast( # (1)! - messages: list[Message], info: AgentInfo - ) -> ModelAnyResponse: - if len(messages) == 2: - # first call, call the weather forecast tool - assert set(info.function_tools.keys()) == {'weather_forecast'} - - user_prompt = messages[1] - m = re.search(r'\d{4}-\d{2}-\d{2}', user_prompt.content) - assert m is not None - args = {'location': 'London', 'forecast_date': m.group()} - return ModelStructuredResponse( - calls=[ToolCall.from_dict('weather_forecast', args)] - ) - else: - # second call, return the forecast - msg = messages[-1] - assert msg.role == 'tool-return' - return ModelTextResponse(f'The forecast is: {msg.content}') +def call_weather_forecast( # (1)! + messages: list[Message], info: AgentInfo +) -> ModelAnyResponse: + if len(messages) == 2: + # first call, call the weather forecast tool + assert set(info.function_tools.keys()) == {'weather_forecast'} + + user_prompt = messages[1] + m = re.search(r'\d{4}-\d{2}-\d{2}', user_prompt.content) + assert m is not None + args = {'location': 'London', 'forecast_date': m.group()} # (2)! + return ModelStructuredResponse( + calls=[ToolCall.from_dict('weather_forecast', args)] + ) + else: + # second call, return the forecast + msg = messages[-1] + assert msg.role == 'tool-return' + return ModelTextResponse(f'The forecast is: {msg.content}') + +async def test_forecast_future(): conn = DatabaseConn() user_id = 1 - with weather_agent.override(model=FunctionModel(call_weather_forecast)): # (2)! + with weather_agent.override(model=FunctionModel(call_weather_forecast)): # (3)! prompt = 'What will the weather be like in London on 2032-01-01?' 
await run_weather_forecast([(prompt, user_id)], conn) @@ -236,8 +236,33 @@ async def test_forecast_future(): assert forecast == 'The forecast is: Rainy with a chance of sun' ``` -1. We define a function `call_weather_forecast` that will be called by `FunctionModel` in place of the LLM, this function has access to the list of [`Message`s][pydantic_ai.messages.Message] that make up the run, and [`AgentInfo`][pydantic_ai.models.function.AgentInfo] which contains information about the agent and the function tools and return type tools. -2. We use [`FunctionModel`][pydantic_ai.models.function.FunctionModel] to replace the agent's model with our custom function. +1. We define a function `call_weather_forecast` that will be called by `FunctionModel` in place of the LLM, this function has access to the list of [`Message`][pydantic_ai.messages.Message]s that make up the run, and [`AgentInfo`][pydantic_ai.models.function.AgentInfo] which contains information about the agent and the function tools and return tools. +2. Our function is slightly intelligent in that it tries to extract a date from the prompt, but just hard codes the location. +3. We use [`FunctionModel`][pydantic_ai.models.function.FunctionModel] to replace the agent's model with our custom function. + +### Overriding model via pytest fixtures + +If you're writing lots of tests that all require model to be overridden, you can use [pytest fixtures](https://docs.pytest.org/en/6.2.x/fixture.html) to override the model with [`TestModel`][pydantic_ai.models.test.TestModel] or [`FunctionModel`][pydantic_ai.models.function.FunctionModel] in a reusable way. + +Here's an example of a fixture that overrides the model with `TestModel`: + +```py title="tests.py" +import pytest +from weather_app import weather_agent + +from pydantic_ai.models.test import TestModel + + +@pytest.fixture +def override_weather_agent(): + with weather_agent.override(model=TestModel()): + yield + + +async def test_forecast(override_weather_agent: None): + ... + # test code here +``` ## Evals diff --git a/pydantic_ai_examples/sql_gen.py b/pydantic_ai_examples/sql_gen.py index 843ebc13..e9ba89e8 100644 --- a/pydantic_ai_examples/sql_gen.py +++ b/pydantic_ai_examples/sql_gen.py @@ -32,7 +32,7 @@ logfire.instrument_asyncpg() DB_SCHEMA = """ -CREATE TABLE IF NOT EXISTS records ( +CREATE TABLE records ( created_at timestamptz, start_timestamp timestamptz, end_timestamp timestamptz, @@ -73,7 +73,7 @@ class InvalidRequest(BaseModel): Response: TypeAlias = Union[Success, InvalidRequest] -agent: Agent[Deps, Response] = Agent( +agent = Agent( 'gemini-1.5-flash', # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else result_type=Response, # type: ignore @@ -87,6 +87,8 @@ async def system_prompt() -> str: Given the following PostgreSQL table of records, your job is to write a SQL query that suits the user's request. +Database schema: + {DB_SCHEMA} today's date = {date.today()}
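Putting the testing pieces above together, a `conftest.py` along the following lines (a minimal sketch, assuming the same illustrative `weather_app` module as the earlier examples) would apply both the `ALLOW_MODEL_REQUESTS` guard and a `TestModel` override to every test in the package:

```py title="conftest.py"
import pytest
from weather_app import weather_agent  # the example agent module used throughout these docs

from pydantic_ai import models
from pydantic_ai.models.test import TestModel

# block accidental requests to real models for the whole test session
models.ALLOW_MODEL_REQUESTS = False


@pytest.fixture(autouse=True)
def use_test_model():
    """Run every test against TestModel instead of the real model."""
    with weather_agent.override(model=TestModel()):
        yield
```

Because the fixture is `autouse=True`, individual tests don't need to request it explicitly; drop `autouse=True` if you'd rather opt in per test, as in the fixture example above.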