evals docs #100

Merged 3 commits on Nov 29, 2024
184 changes: 180 additions & 4 deletions docs/testing-evals.md
@@ -266,14 +266,190 @@ async def test_forecast(override_weather_agent: None):

## Evals

"Evals" refers to evaluating the performance of an LLM when used in a specific context.
"Evals" refers to evaluating a models performance for a specific application.

!!! danger "Warning"
    Unlike unit tests, evals are an emerging art/science; anyone who claims to know for sure exactly how your evals should be defined can safely be ignored.

Evals are generally more like benchmarks than unit tests: they never "pass", although they do "fail"; you care mostly about how they change over time.

Since evals need to be run against the real model, they can be slow and expensive to run, so you generally won't want to run them in CI for every commit.
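
One common pattern is to keep evals in the test suite but gate them behind an environment variable, so they only run when explicitly requested (e.g. in a scheduled CI job). A minimal sketch; the `RUN_EVALS` variable name is just an assumption, not a pydantic-ai convention:

```py
import os

import pytest

# skip this whole module unless evals are explicitly requested,
# e.g. `RUN_EVALS=1 pytest tests/evals` in a nightly job
pytestmark = pytest.mark.skipif(
    os.environ.get('RUN_EVALS') != '1',
    reason='evals are slow and expensive, set RUN_EVALS=1 to run them',
)
```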

### Measuring performance

The hardest part of evals is measuring how well the model has performed.

In some cases (e.g. an agent to generate SQL) there are simple, easy-to-run tests that can be used to measure performance (e.g. is the SQL valid? Does it return the right results? Does it return just the right results?).

In other cases (e.g. an agent that gives advice on quitting smoking) it can be very hard or impossible to make quantitative measures of performance — in the smoking case you'd really need to run a double-blind trial over months, then wait 40 years and observe health outcomes to know if changes to your prompt were an improvement.

There are a few different strategies you can use to measure performance:

* **End to end, self-contained tests** — like the SQL example, we can test the final result of the agent near-instantly
* **Synthetic self-contained tests** — writing unit-test-style checks that the output is as expected, e.g. `#!python 'chewing gum' in response`. While these checks might seem simplistic, they can be helpful; one nice characteristic is that it's easy to tell what's wrong when they fail (see the sketch after this list)
* **LLMs evaluating LLMs** — using another model, or even the same model with a different prompt, to evaluate the performance of the agent (like when the class marks each other's homework because the teacher has a hangover). While the downsides and complexities of this approach are obvious, some think it can be a useful tool in the right circumstances
* **Evals in prod** — measuring the end results of the agent in production, then creating a quantitative measure of performance, so you can easily measure changes over time as you change the prompt or model used. [Logfire](logfire.md) can be extremely useful here since you can write a custom query to measure the performance of your agent
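
To illustrate the "synthetic self-contained tests" bullet above: such a check can be a one-line assertion in an ordinary unit test. Here's a minimal sketch, assuming an async-capable pytest setup and a hypothetical `quit_smoking_agent`:

```py
import pytest

from quit_smoking_app import quit_smoking_agent  # hypothetical module and agent


@pytest.mark.anyio
async def test_mentions_chewing_gum():
    result = await quit_smoking_agent.run('What can I do when I get a craving?')
    # simplistic keyword check, but when it fails it's easy to tell what went wrong
    assert 'chewing gum' in result.data.lower()
```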

### System prompt customization

The system prompt is the developer's primary tool in controlling an agent's behavior, so it's often useful to be able to customise the system prompt and see how performance changes. This is particularly relevant when the system prompt contains a list of examples and you want to understand how changing that list affects the model's performance.

Let's assume we have the following app for running SQL generated from a user prompt (this example omits a lot of details for brevity; see the [SQL gen](examples/sql-gen.md) example for more complete code):

```py title="sql_app.py"
import json
from pathlib import Path
from typing import Union

from pydantic_ai import Agent, CallContext

from fake_database import DatabaseConn


class SqlSystemPrompt: # (1)!
def __init__(
self, examples: Union[list[dict[str, str]], None] = None, db: str = 'PostgreSQL'
):
if examples is None:
# if examples aren't provided, load them from file, this is the default
with Path('examples.json').open('rb') as f:
self.examples = json.load(f)
else:
self.examples = examples

self.db = db

def build_prompt(self) -> str: # (2)!
return f"""\
Given the following {self.db} table of records, your job is to
write a SQL query that suits the user's request.

Database schema:
CREATE TABLE records (
...
);

{''.join(self.format_example(example) for example in self.examples)}
"""

@staticmethod
def format_example(example: dict[str, str]) -> str: # (3)!
return f"""\
<example>
<request>{example['request']}</request>
<sql>{example['sql']}</sql>
</example>
"""


sql_agent = Agent(
'gemini-1.5-flash',
deps_type=SqlSystemPrompt,
)


@sql_agent.system_prompt
async def system_prompt(ctx: CallContext[SqlSystemPrompt]) -> str:
return ctx.deps.build_prompt()


async def user_search(user_prompt: str) -> list[dict[str, str]]:
"""Search the database based on the user's prompts."""
... # (4)!
result = await sql_agent.run(user_prompt, deps=SqlSystemPrompt())
conn = DatabaseConn()
return await conn.execute(result.data)
```
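
If you want to poke at the app by hand, it can be invoked with a plain `asyncio` runner (a hypothetical one-off invocation, not part of the documented example):

```py
import asyncio

from sql_app import user_search

# run a single ad-hoc search against the agent and print the rows it returns
rows = asyncio.run(user_search('show me error records with the tag "foobar"'))
print(rows)
```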

`examples.json` looks something like this:


```json title="examples.json"
{
"examples": [
{
"request": "Show me all records",
"sql": "SELECT * FROM records;"
},
{
"request": "Show me all records from 2021",
"sql": "SELECT * FROM records WHERE date_trunc('year', date) = '2021-01-01';"
},
{
"request": "show me error records with the tag 'foobar'",
"sql": "SELECT * FROM records WHERE level = 'error' and 'foobar' = ANY(tags);"
},
...
]
}
```

Now we want a way to quantify the success of the SQL generation so we can judge how changes to the agent affect its performance.

We can use [`Agent.override`][pydantic_ai.agent.Agent.override] to replace the system prompt with a custom one that uses a subset of examples, and then run the application code (in this case `user_search`). We also run the actual SQL from the examples and compare the "correct" result from the example SQL to the SQL generated by the agent. (We compare the results of running the SQL rather than the SQL itself since the SQL might be semantically equivalent but written in a different way).

To get a quantitative measure of performance, we assign points to each run as follows (the same rule is sketched as a small helper function after this list):

* **-100** points if the generated SQL is invalid
* **-1** point for each row returned by the agent (so returning lots of results is discouraged)
* **+5** points for each row returned by the agent that matches the expected result
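
Purely as an illustration (this helper isn't part of the app or the test script below, which inlines the same logic), the scoring rule amounts to:

```py
from typing import Union


def score_case(agent_ids: Union[list[int], None], expected_ids: set[int]) -> int:
    """Score one example; `agent_ids` is None when the generated SQL was invalid."""
    if agent_ids is None:
        return -100
    # -1 per returned row, +5 per returned row that matches the expected result
    return -len(agent_ids) + 5 * len(set(agent_ids) & expected_ids)
```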

We use 5-fold cross-validation to judge the performance of the agent using our existing set of examples.

```py title="test_sql_app.py"
import json
import statistics
from pathlib import Path
from itertools import chain

from fake_database import DatabaseConn, QueryError
from sql_app import sql_agent, SqlSystemPrompt, user_search


async def main():
with Path('examples.json').open('rb') as f:
examples = json.load(f)

# split examples into 5 folds
fold_size = len(examples) // 5
folds = [examples[i : i + fold_size] for i in range(0, len(examples), fold_size)]
conn = DatabaseConn()
scores = []

for i, fold in enumerate(folds, start=1):
fold_score = 0
# build all other folds into a list of examples
other_folds = list(chain(*(f for j, f in enumerate(folds) if j != i)))
# create a new system prompt with the other fold examples
system_prompt = SqlSystemPrompt(examples=other_folds)

# override the system prompt with the new one
with sql_agent.override(deps=system_prompt):
for case in fold:
try:
agent_results = await user_search(case['request'])
except QueryError as e:
print(f'Fold {i} {case}: {e}')
fold_score -= 100
else:
# get the expected results using the SQL from this case
expected_results = await conn.execute(case['sql'])

agent_ids = [r['id'] for r in agent_results]
# each returned value has a score of -1
fold_score -= len(agent_ids)
expected_ids = {r['id'] for r in expected_results}

# each return value that matches the expected value has a score of 3
fold_score += 5 * len(set(agent_ids) & expected_ids)

scores.append(fold_score)

overall_score = statistics.mean(scores)
print(f'Overall score: {overall_score:0.2f}')
#> Overall score: 12.00
```
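
The script above only defines `main()`; to run it as a standalone eval you'd still need an entry point along these lines (assuming you run it directly rather than via a test runner):

```py
# hypothetical entry point, appended to test_sql_app.py
if __name__ == '__main__':
    import asyncio

    asyncio.run(main())
```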

We can then change the prompt, the model, or the examples and see how the score changes over time.
25 changes: 20 additions & 5 deletions tests/test_examples.py
@@ -1,10 +1,13 @@
from __future__ import annotations as _annotations

import json
import os
import re
import sys
from collections.abc import AsyncIterator, Iterable
from dataclasses import dataclass, field
from datetime import date
from pathlib import Path
from types import ModuleType
from typing import Any

@@ -51,8 +54,8 @@ class DatabaseConn:
    users: FakeTable = field(default_factory=FakeTable)
    _forecasts: dict[int, str] = field(default_factory=dict)

    async def execute(self, query: str) -> list[dict[str, Any]]:
        return [{'id': 123, 'name': 'John Doe'}]

    async def store_forecast(self, user_id: int, forecast: str) -> None:
        self._forecasts[user_id] = forecast
@@ -129,6 +132,7 @@ def test_docs_examples(
    mocker: MockerFixture,
    client_with_handler: ClientWithHandler,
    env: TestEnv,
    tmp_path: Path,
):
    mocker.patch('pydantic_ai.agent.models.infer_model', side_effect=mock_infer_model)
    mocker.patch('pydantic_ai._utils.group_by_temporal', side_effect=mock_group_by_temporal)
@@ -145,6 +149,14 @@
    env.set('GROQ_API_KEY', 'testing')

    prefix_settings = example.prefix_settings()
    opt_title = prefix_settings.get('title')
    cwd = Path.cwd()

    if opt_title == 'test_sql_app.py':
        os.chdir(tmp_path)
        examples = [{'request': f'sql prompt {i}', 'sql': f'SELECT {i}'} for i in range(15)]
        with (tmp_path / 'examples.json').open('w') as f:
            json.dump(examples, f)

    ruff_ignore: list[str] = ['D']
    # `from bank_database import DatabaseConn` wrongly sorted in imports
@@ -153,7 +165,7 @@
        ruff_ignore.append('I001')

    line_length = 88
    if opt_title in ('streamed_hello_world.py', 'streamed_user_profile.py'):
        line_length = 120

    eval_example.set_config(ruff_ignore=ruff_ignore, target_version='py39', line_length=line_length)
@@ -173,8 +185,8 @@
        eval_example.lint(example)
        module_dict = eval_example.run_print_check(example, call=call_name)

    os.chdir(cwd)
    if title := opt_title:
        if title.endswith('.py'):
            module_name = title[:-3]
            sys.modules[module_name] = module = ModuleType(module_name)
@@ -275,6 +287,9 @@ async def model_logic(messages: list[Message], info: AgentInfo) -> ModelAnyRespo
            else:
                return ModelStructuredResponse(calls=[response])

        if re.fullmatch(r'sql prompt \d+', m.content):
            return ModelTextResponse(content='SELECT 1')

    elif m.role == 'tool-return' and m.tool_name == 'roulette_wheel':
        win = m.content == 'winner'
        return ModelStructuredResponse(calls=[ToolCall(tool_name='final_result', args=ArgsDict({'response': win}))])