Skip to content

Commit c65c67a

Browse files
committed
address comments of DouweM in the data analyst example
1 parent 62281e7 commit c65c67a

File tree

1 file changed

+34
-23
lines changed

1 file changed

+34
-23
lines changed

examples/pydantic_ai_examples/data_analyst.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,30 @@
11
from dataclasses import dataclass
22

3+
import datasets
4+
import duckdb
5+
import pandas as pd
36
from devtools import debug
47

58
from pydantic_ai import Agent, ModelRetry, RunContext
69

7-
try:
8-
import datasets
9-
import duckdb
10-
import pandas as pd
11-
except ImportError as e:
12-
raise ImportError(
13-
'Please install both duckdb and pandas.\n'
14-
'- pip: `pip install duckdb pandas\n'
15-
'- uv: `uv pip install duckdb pandas'
16-
) from e
17-
1810

1911
@dataclass
2012
class AnalystAgentDeps:
2113
output: dict[str, pd.DataFrame]
2214

15+
def store(self, value: pd.DataFrame) -> str:
16+
"""Store the output in deps and return the reference such as Out[1] to be used by the LLM."""
17+
ref = f'Out[{len(self.output) + 1}]'
18+
self.output[ref] = value
19+
return ref
20+
21+
def get(self, ref: str) -> pd.DataFrame:
22+
if ref not in self.output:
23+
raise ModelRetry(
24+
f'Error: {ref} is not a valid variable reference. Check the previous messages and try again.'
25+
)
26+
return self.output[ref]
27+
2328

2429
analyst_agent = Agent(
2530
'openai:gpt-4o',
@@ -41,6 +46,7 @@ def load_dataset(
4146
path: name of the dataset in the form of `<user_name>/<dataset_name>`
4247
split: load the split of the dataset (default: "train")
4348
"""
49+
# begin load data from hf
4450
builder = datasets.load_dataset_builder(path) # pyright: ignore[reportUnknownMemberType]
4551
splits: dict[str, datasets.SplitInfo] = builder.info.splits or {} # pyright: ignore[reportUnknownMemberType]
4652
if split not in splits:
@@ -53,14 +59,19 @@ def load_dataset(
5359
assert isinstance(dataset, datasets.Dataset)
5460
dataframe = dataset.to_pandas()
5561
assert isinstance(dataframe, pd.DataFrame)
56-
ref = f'Out[{len(ctx.deps.output) + 1}]'
57-
ctx.deps.output[ref] = dataframe
58-
output = [f'Loaded the dataset as `{ref}`.']
59-
if dataset.info.description:
60-
output.append(f'Description: {dataset.info.description}')
61-
if dataset.info.features:
62-
output.append(f'Features: {dataset.info.features!r}')
63-
return '\n'.join(output)
62+
# end load data from hf
63+
64+
# store the dataframe in the deps and get a ref like "Out[1]"
65+
ref = ctx.deps.store(dataframe)
66+
# construct a summary of the loaded dataset
67+
output = [
68+
f'Loaded the dataset as `{ref}`.',
69+
f'Description: {dataset.info.description}'
70+
if dataset.info.description
71+
else None,
72+
f'Features: {dataset.info.features!r}' if dataset.info.features else None,
73+
]
74+
return '\n'.join(filter(None, output))
6475

6576

6677
@analyst_agent.tool
@@ -76,15 +87,15 @@ def run_duckdb(ctx: RunContext[AnalystAgentDeps], dataset: str, sql: str) -> str
7687
"""
7788
data = ctx.deps.output[dataset]
7889
result = duckdb.query_df(df=data, virtual_table_name='dataset', sql_query=sql)
79-
ref = f'Out[{len(ctx.deps.output) + 1}]'
80-
ctx.deps.output[ref] = result.df() # pyright: ignore[reportUnknownMemberType]
90+
# pass the result as ref (because DuckDB SQL can select many rows, creating another huge dataframe)
91+
ref = ctx.deps.store(result.df()) # pyright: ignore[reportUnknownMemberType]
8192
return f'Executed SQL, result is `{ref}`'
8293

8394

8495
@analyst_agent.tool
8596
def display(ctx: RunContext[AnalystAgentDeps], name: str) -> str:
86-
"""Display the dataframe at most 5 rows."""
87-
dataset = ctx.deps.output[name]
97+
"""Display at most 5 rows of the dataframe ."""
98+
dataset = ctx.deps.get(name)
8899
return dataset.head().to_string() # pyright: ignore[reportUnknownMemberType]
89100

90101

0 commit comments

Comments
 (0)