11from dataclasses import dataclass
22
3+ import datasets
4+ import duckdb
5+ import pandas as pd
36from devtools import debug
47
58from pydantic_ai import Agent , ModelRetry , RunContext
69
7- try :
8- import datasets
9- import duckdb
10- import pandas as pd
11- except ImportError as e :
12- raise ImportError (
13- 'Please install both duckdb and pandas.\n '
14- '- pip: `pip install duckdb pandas\n '
15- '- uv: `uv pip install duckdb pandas'
16- ) from e
17-
1810
1911@dataclass
2012class AnalystAgentDeps :
2113 output : dict [str , pd .DataFrame ]
2214
15+ def store (self , value : pd .DataFrame ) -> str :
16+ """Store the output in deps and return the reference such as Out[1] to be used by the LLM."""
17+ ref = f'Out[{ len (self .output ) + 1 } ]'
18+ self .output [ref ] = value
19+ return ref
20+
21+ def get (self , ref : str ) -> pd .DataFrame :
22+ if ref not in self .output :
23+ raise ModelRetry (
24+ f'Error: { ref } is not a valid variable reference. Check the previous messages and try again.'
25+ )
26+ return self .output [ref ]
27+
2328
2429analyst_agent = Agent (
2530 'openai:gpt-4o' ,
@@ -41,6 +46,7 @@ def load_dataset(
4146 path: name of the dataset in the form of `<user_name>/<dataset_name>`
4247 split: load the split of the dataset (default: "train")
4348 """
49+ # begin load data from hf
4450 builder = datasets .load_dataset_builder (path ) # pyright: ignore[reportUnknownMemberType]
4551 splits : dict [str , datasets .SplitInfo ] = builder .info .splits or {} # pyright: ignore[reportUnknownMemberType]
4652 if split not in splits :
@@ -53,14 +59,19 @@ def load_dataset(
5359 assert isinstance (dataset , datasets .Dataset )
5460 dataframe = dataset .to_pandas ()
5561 assert isinstance (dataframe , pd .DataFrame )
56- ref = f'Out[{ len (ctx .deps .output ) + 1 } ]'
57- ctx .deps .output [ref ] = dataframe
58- output = [f'Loaded the dataset as `{ ref } `.' ]
59- if dataset .info .description :
60- output .append (f'Description: { dataset .info .description } ' )
61- if dataset .info .features :
62- output .append (f'Features: { dataset .info .features !r} ' )
63- return '\n ' .join (output )
62+ # end load data from hf
63+
64+ # store the dataframe in the deps and get a ref like "Out[1]"
65+ ref = ctx .deps .store (dataframe )
66+ # construct a summary of the loaded dataset
67+ output = [
68+ f'Loaded the dataset as `{ ref } `.' ,
69+ f'Description: { dataset .info .description } '
70+ if dataset .info .description
71+ else None ,
72+ f'Features: { dataset .info .features !r} ' if dataset .info .features else None ,
73+ ]
74+ return '\n ' .join (filter (None , output ))
6475
6576
6677@analyst_agent .tool
@@ -76,15 +87,15 @@ def run_duckdb(ctx: RunContext[AnalystAgentDeps], dataset: str, sql: str) -> str
7687 """
7788 data = ctx .deps .output [dataset ]
7889 result = duckdb .query_df (df = data , virtual_table_name = 'dataset' , sql_query = sql )
79- ref = f'Out[ { len ( ctx . deps . output ) + 1 } ]'
80- ctx . deps . output [ ref ] = result .df () # pyright: ignore[reportUnknownMemberType]
90+ # pass the result as ref (because DuckDB SQL can select many rows, creating another huge dataframe)
91+ ref = ctx . deps . store ( result .df () ) # pyright: ignore[reportUnknownMemberType]
8192 return f'Executed SQL, result is `{ ref } `'
8293
8394
8495@analyst_agent .tool
8596def display (ctx : RunContext [AnalystAgentDeps ], name : str ) -> str :
86- """Display the dataframe at most 5 rows."""
87- dataset = ctx .deps .output [ name ]
97+ """Display at most 5 rows of the dataframe ."""
98+ dataset = ctx .deps .get ( name )
8899 return dataset .head ().to_string () # pyright: ignore[reportUnknownMemberType]
89100
90101
0 commit comments