diff --git a/lib/questions_eval/run_mimoracle.py b/lib/questions_eval/run_mimoracle.py
index f666736..e51b00c 100644
--- a/lib/questions_eval/run_mimoracle.py
+++ b/lib/questions_eval/run_mimoracle.py
@@ -249,6 +249,19 @@ def merge_evaluate_df(df: pd.DataFrame, row: pd.Series, evaluation_chain) -> pd.
     return df_joined
 
 
+def log_wandb(df: pd.DataFrame) -> dict:
+    log_dict = {
+        f"{stat}/score/{score_type}": (
+            df[f"{score_type}"].agg(stat)
+            if stat == "sum"
+            else df[f"{score_type}"].agg("sum") / len(df)
+        )
+        for stat in ["sum", "mean"]
+        for score_type in ["consistency", "conformity", "coverage"]
+    }
+    return log_dict
+
+
 @hydra.main(config_path="./configs", config_name="run_mimoracle.yaml")
 def main(cfg: DictConfig):
     # Initialize WandB and log the models
@@ -283,15 +296,7 @@ def main(cfg: DictConfig):
     df_joined = merge_evaluate_df(df, df_questions, evaluation_chain)
 
     # Log results in wandb
-    log_dict = {
-        f"{stat}/score/{score_type}": (
-            df_joined[f"{score_type}"].agg(stat)
-            if stat == "sum"
-            else df_joined[f"{score_type}"].agg("sum") / len(df_joined)
-        )
-        for stat in ["sum", "mean"]
-        for score_type in ["consistency", "conformity", "coverage"]
-    }
+    log_dict = log_wandb(df_joined)
     for key, value in log_dict.items():
         wandb.run.summary[key] = value
     wandb.log({"dataset/evaluation_mimoracle_gpt4o_retest": wandb.Table(dataframe=df_joined)})
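
Illustrative sketch, not part of the patch: how the extracted log_wandb helper behaves on a toy evaluation DataFrame. The toy data and the assumption that the score columns hold 0/1 values are hypothetical; the key scheme "{stat}/score/{score_type}" and the column names match the comprehension that the patch moves out of main().

# Standalone sketch of the refactored helper; the toy score values are assumed.
import pandas as pd

def log_wandb(df: pd.DataFrame) -> dict:
    log_dict = {
        f"{stat}/score/{score_type}": (
            df[f"{score_type}"].agg(stat)
            if stat == "sum"
            else df[f"{score_type}"].agg("sum") / len(df)
        )
        for stat in ["sum", "mean"]
        for score_type in ["consistency", "conformity", "coverage"]
    }
    return log_dict

toy = pd.DataFrame({
    "consistency": [1, 0, 1],
    "conformity": [1, 1, 1],
    "coverage": [0, 1, 1],
})
print(log_wandb(toy))
# Keys run sum/... then mean/..., e.g. 'sum/score/consistency': 2 and
# 'mean/score/consistency': 0.666..., since the mean branch divides the sum by len(df).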