From b30adb9ac6026ea40b130e8b7babfad26f5724b2 Mon Sep 17 00:00:00 2001 From: Shayekh Bin Islam Date: Fri, 11 Oct 2024 20:19:10 +0600 Subject: [PATCH] maple/agreement codes --- analysis/avg_agreement_final.py | 92 +++++++++++++++++++++++++++++ analysis/maple_results.py | 101 ++++++++++++++++++++++++++++++++ 2 files changed, 193 insertions(+) create mode 100644 analysis/avg_agreement_final.py create mode 100644 analysis/maple_results.py diff --git a/analysis/avg_agreement_final.py b/analysis/avg_agreement_final.py new file mode 100644 index 0000000..cb93b89 --- /dev/null +++ b/analysis/avg_agreement_final.py @@ -0,0 +1,92 @@ +import seaborn as sns +import matplotlib.pyplot as plt +import numpy as np + +data = { + "meta-llama/Meta-Llama-3.1-8B-Instruct": [ + 0.3533086666014079, + 0.052422082615756406 + ], + "cohere/c4ai-aya-23-35b": [ + 0.43767196047824003, + 0.026040919354464294 + ], + "cohere/c4ai-aya-23-8b": [ + 0.013483014909052663, + 0.03363706833599835 + ], + "cohere/command-r-08-2024": [ + 0.374457668650282, + 0.02926089754079793 + ], + "cohere/command-r-plus-08-2024": [ + 0.3830841816733316, + 0.020185255968455686 + ], + "google/gemma-1.1-7b-it": [ + 0.5190375637539242, + 0.027757722654111305 + ], + "google/gemma-2-9b-it": [ + 0.5181663123111222, + 0.031090119385244894 + ], + "meta-llama/Meta-Llama-3-70B-Instruct": [ + 0.5685224105896568, + 0.04853344616275034 + ], + "meta-llama/Meta-Llama-3-8B-Instruct": [ + 0.37936948540837095, + 0.032172769265151994 + ], + "meta-llama/Meta-Llama-3.1-70B-Instruct": [ + 0.603536768244583, + 0.027191895488989915 + ], + "mistralai/Mistral-7B-Instruct-v0.2": [ + 0.4071166722276529, + 0.04577594028555328 + ], + "mistralai/Mistral-7B-Instruct-v0.3": [ + 0.41195018984687265, + 0.056184679972755454 + ], + "openai/gpt-4-turbo-2024-04-09": [ + 0.6106943361444249, + 0.02932446842558468 + ], + "openai/gpt-4o-2024-05-13": [ + 0.5833874065757011, + 0.023695391445384514 + ] +} + +sorted_data = dict(sorted(data.items(), key=lambda item: item[1][0])) +labels_sorted = list(sorted_data.keys()) +means_sorted = [v[0] for v in sorted_data.values()] +std_devs_sorted = [v[1] for v in sorted_data.values()] + +sns.set(style="whitegrid") +palette = sns.color_palette("coolwarm", len(labels_sorted)) + +plt.figure(figsize=(10, 6)) +x_pos_sorted = np.arange(len(labels_sorted)) + +ax1 = sns.barplot(x=x_pos_sorted, y=means_sorted, palette=palette, errorbar=None) +plt.errorbar(x_pos_sorted, means_sorted, yerr=std_devs_sorted, fmt='none', c='black', capsize=5) + +ax1.spines['top'].set_color('black') +ax1.spines['right'].set_color('black') +ax1.spines['left'].set_color('black') +ax1.spines['bottom'].set_color('black') +for spine in ax1.spines.values(): + spine.set_linewidth(2) # Make the border thicker + +plt.ylim(0, 0.8) + +plt.xticks(x_pos_sorted, labels_sorted, rotation=90) +plt.ylabel("Cohen's Kappa") +plt.title('Average Inner-Model Agreement Across Languages') + +plt.tight_layout() +plt.savefig(f"./innermodel_agreement.pdf", bbox_inches='tight') \ No newline at end of file diff --git a/analysis/maple_results.py b/analysis/maple_results.py new file mode 100644 index 0000000..45ee3d0 --- /dev/null +++ b/analysis/maple_results.py @@ -0,0 +1,101 @@ +import json +from pathlib import Path + +import argparse +import logging +from pathlib import Path +from typing import Optional + +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt +from huggingface_hub import snapshot_download +import datasets +import json + +import numpy as np +import matplotlib.pyplot as plt +from itertools import combinations +from collections import defaultdict + + +FONT_SIZES = {"small": 12, "medium": 16, "large": 18} + +PLOT_PARAMS = { + "font.family": "serif", + "font.serif": ["Times New Roman", "STIX"], + "font.size": FONT_SIZES.get("medium"), + "axes.titlesize": FONT_SIZES.get("large"), + "axes.labelsize": FONT_SIZES.get("large"), + "xtick.labelsize": FONT_SIZES.get("large"), + "ytick.labelsize": FONT_SIZES.get("small"), + "legend.fontsize": FONT_SIZES.get("medium"), + "figure.titlesize": FONT_SIZES.get("medium"), + "text.usetex": False, +} + +logging.basicConfig(level=logging.INFO) + +plt.rcParams.update(PLOT_PARAMS) + +def load_json(json_file_path): + with open(json_file_path, "r") as file: + json_data = json.load(file) + return json_data + +results_dir = 'data/eval-results-maple' +results_path = Path(results_dir) + +results_all = [] +for result_file in results_path.glob("*.json"): + raw_results = load_json(result_file) + if "leaderboard" in raw_results.keys(): + model_id = raw_results["model"] + subset_results = raw_results['subset'] + overall = raw_results['scores']['accuracy'] + remove_key = ['model', 'model_type', 'chat_template'] + for key in remove_key: + del subset_results[key] + elif "subset_results" in raw_results.keys(): + model_id = raw_results["model"] + subset_results = raw_results['subset_results'] + overall = raw_results['accuracy'] + else: + model_id = raw_results["model"] + subset_results = raw_results['extra_results'] + overall = raw_results['accuracy'] + # print(model_id, overall) + # print("\t", subset_results) + # results_all.append([model_id, overall, subset_results]) + results_all.append({'Model': model_id, 'Avg': overall, **subset_results}) + + # import ipdb; ipdb.set_trace() + +TOP = 10 +# results_all.sort(key=lambda x: x[1], reverse=True) +# results_all = results_all[:TOP] +# print(results_all) + +df_results = pd.DataFrame(results_all) +df_results = df_results.sort_values(by='Avg', ascending=False).reset_index(drop=True) +df_results = df_results.head(10).reset_index(drop=True) + +df_results.columns = df_results.columns.str.replace('^maple-', '', regex=True) +df_results = df_results.set_index("Model") +df_results = df_results * 100 +fig, ax = plt.subplots(1, 1, figsize=(18, 5)) + +sns.heatmap(df_results, ax=ax, cmap="YlGn", annot=True, annot_kws={"size": 16}, + fmt=".1f", cbar=False) + +ax.xaxis.set_ticks_position("top") +ax.tick_params(axis="x", labelrotation=45) +ax.set_ylabel("") +ax.set_yticklabels([f"{model} " for model in df_results.index]) + +plt.tight_layout() + +plt.savefig("plots/maple.pdf", bbox_inches="tight") +# import ipdb; ipdb.set_trace() + +