Skip to content

Commit

Permalink
refactor: create log_wandb method
Browse files Browse the repository at this point in the history
  • Loading branch information
honghanhh committed Oct 30, 2024
1 parent a797689 commit 1424987
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions lib/questions_eval/run_mimoracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,19 @@ def merge_evaluate_df(df: pd.DataFrame, row: pd.Series, evaluation_chain) -> pd.
return df_joined


def log_wandb(df: pd.DataFrame) -> dict:
    """Build the summary metrics for the three evaluation score columns.

    Despite its name, this function does not call wandb itself; it only
    returns the dictionary that the caller is expected to log.

    Args:
        df: Evaluation results. Must contain the columns ``consistency``,
            ``conformity`` and ``coverage``.

    Returns:
        A dict keyed by ``"<stat>/score/<score_type>"`` with ``stat`` in
        ``("sum", "mean")``. All ``sum`` entries come first, then all
        ``mean`` entries, each in the column order listed above.

    Raises:
        KeyError: If one of the score columns is missing from ``df``.
    """
    score_types = ("consistency", "conformity", "coverage")
    n_rows = len(df)

    # Sum each column exactly once and reuse it for the mean below.
    totals = {score_type: df[score_type].agg("sum") for score_type in score_types}

    log_dict = {
        f"sum/score/{score_type}": total for score_type, total in totals.items()
    }
    # NOTE: the mean is intentionally "sum / number of rows", not
    # ``Series.mean()``: rows whose score is NaN are skipped by the sum but
    # still count in the denominator. An empty frame yields NaN explicitly
    # instead of a ZeroDivisionError (object dtype) or a NumPy runtime
    # warning (numeric dtype).
    log_dict.update(
        {
            f"mean/score/{score_type}": (
                total / n_rows if n_rows else float("nan")
            )
            for score_type, total in totals.items()
        }
    )
    return log_dict


@hydra.main(config_path="./configs", config_name="run_mimoracle.yaml")
def main(cfg: DictConfig):
# Initialize WandB and log the models
Expand Down Expand Up @@ -283,15 +296,7 @@ def main(cfg: DictConfig):
df_joined = merge_evaluate_df(df, df_questions, evaluation_chain)

# Log results in wandb
log_dict = {
f"{stat}/score/{score_type}": (
df_joined[f"{score_type}"].agg(stat)
if stat == "sum"
else df_joined[f"{score_type}"].agg("sum") / len(df_joined)
)
for stat in ["sum", "mean"]
for score_type in ["consistency", "conformity", "coverage"]
}
log_dict = log_wandb(df_joined)
for key, value in log_dict.items():
wandb.run.summary[key] = value
wandb.log({"dataset/evaluation_mimoracle_gpt4o_retest": wandb.Table(dataframe=df_joined)})
Expand Down

0 comments on commit 1424987

Please sign in to comment.