1- from dataclasses import dataclass
1+ from dataclasses import dataclass , field
22
33import datasets
44import duckdb
99
1010@dataclass
1111class AnalystAgentDeps :
12- output : dict [str , pd .DataFrame ]
12+ output : dict [str , pd .DataFrame ] = field ( default_factory = dict )
1313
1414 def store (self , value : pd .DataFrame ) -> str :
1515 """Store the output in deps and return the reference such as Out[1] to be used by the LLM."""
@@ -84,7 +84,7 @@ def run_duckdb(ctx: RunContext[AnalystAgentDeps], dataset: str, sql: str) -> str
8484 dataset: reference string to the DataFrame
8585 sql: the query to be executed using DuckDB
8686 """
87- data = ctx .deps .output [ dataset ]
87+ data = ctx .deps .get ( dataset )
8888 result = duckdb .query_df (df = data , virtual_table_name = 'dataset' , sql_query = sql )
8989 # pass the result as ref (because DuckDB SQL can select many rows, creating another huge dataframe)
9090 ref = ctx .deps .store (result .df ()) # pyright: ignore[reportUnknownMemberType]
@@ -93,13 +93,13 @@ def run_duckdb(ctx: RunContext[AnalystAgentDeps], dataset: str, sql: str) -> str
9393
@analyst_agent.tool
def display(ctx: RunContext[AnalystAgentDeps], name: str) -> str:
    """Display at most 5 rows of the dataframe.

    Args:
        name: reference string to the DataFrame
    """
    # Resolve the reference through the deps accessor, the same way the
    # sibling tool run_duckdb looks up its input DataFrame.
    dataset = ctx.deps.get(name)
    # head() is called without an argument, so only the leading rows are
    # rendered; to_string() turns them into plain text for the tool result.
    return dataset.head().to_string()  # pyright: ignore[reportUnknownMemberType]
9999
100100
101101if __name__ == '__main__' :
102- deps = AnalystAgentDeps (output = {} )
102+ deps = AnalystAgentDeps ()
103103 result = analyst_agent .run_sync (
104104 user_prompt = 'Count how many negative comments are there in the dataset `cornell-movie-review-data/rotten_tomatoes`' ,
105105 deps = deps ,
0 commit comments