From f766238b99ddf73ca0bd1e825404868c6f6649d8 Mon Sep 17 00:00:00 2001 From: Way2Learn <118058822+Xisen-Wang@users.noreply.github.com> Date: Mon, 26 Aug 2024 16:38:49 +0800 Subject: [PATCH] feat: update model_experiment.py to support basic EDA (#220) * Update model_experiment.py to support basic eda It looks into the data first before the proposal. * Update model_experiment.py Revised linting * Update model_experiment.py by fixing sorting order * Update model_experiment.py for black linting * Update model_experiment.py * Update model_experiment.py * Update model_experiment.py * Update model_experiment.py --------- Co-authored-by: WinstonLiyt <104308117+WinstonLiyt@users.noreply.github.com> --- .../kaggle/experiment/model_experiment.py | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/rdagent/scenarios/kaggle/experiment/model_experiment.py b/rdagent/scenarios/kaggle/experiment/model_experiment.py index 889b161b..8cab9d21 100644 --- a/rdagent/scenarios/kaggle/experiment/model_experiment.py +++ b/rdagent/scenarios/kaggle/experiment/model_experiment.py @@ -1,6 +1,7 @@ import json from pathlib import Path +import pandas as pd from jinja2 import Environment, StrictUndefined from rdagent.components.coder.model_coder.model import ( @@ -77,12 +78,34 @@ def background(self) -> str: competition_features=self.competition_features, ) ) - return background_prompt @property def source_data(self) -> str: - raise NotImplementedError("source_data is not implemented") + kaggle_conf = KGDockerConf() + data_path = Path(f"{kaggle_conf.share_data_path}/{self.competition}") + + csv_files = list(data_path.glob("*.csv")) + + if not csv_files: + return "No CSV files found in the specified path." + + dataset = pd.concat([pd.read_csv(file) for file in csv_files], ignore_index=True) + + simple_eda = dataset.info(buf=None) # Capture the info output + data_shape = dataset.shape + data_head = dataset.head() + + eda = ( + f"Basic Info about the data:\n{simple_eda}\n" + f"Shape of the dataset: {data_shape}\n" + f"Sample Data:\n{data_head}\n" + ) + + data_description = self.competition_descriptions.get("Data Description", "No description provided") + eda += f"\nData Description:\n{data_description}" + + return eda @property def output_format(self) -> str: