Skip to content

Commit

Permalink
plot model families
Browse files Browse the repository at this point in the history
  • Loading branch information
slobentanzer committed Aug 9, 2024
1 parent 46ce9a6 commit 6a7693f
Show file tree
Hide file tree
Showing 13 changed files with 38 additions and 2 deletions.
Binary file modified docs/images/boxplot-naive-vs-biochatter.pdf
Binary file not shown.
Binary file modified docs/images/boxplot-text2cypher.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/dotplot-per-task.pdf
Binary file not shown.
Binary file modified docs/images/dotplot-per-task.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/scatter-per-quantisation-name.pdf
Binary file not shown.
Binary file modified docs/images/scatter-per-quantisation-name.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/scatter-quantisation-accuracy.pdf
Binary file not shown.
Binary file modified docs/images/scatter-size-accuracy.pdf
Binary file not shown.
Binary file modified docs/images/stripplot-extraction-tasks.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/stripplot-per-model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/images/stripplot-rag-tasks.pdf
Binary file not shown.
Binary file modified docs/images/stripplot-rag-tasks.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
40 changes: 38 additions & 2 deletions docs/scripts/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ def plot_text2cypher():
"""
Get entity_selection, relationship_selection, property_selection,
property_exists, and end_to_end_query_generation results files, combine and
preprocess them and plot the accuracy for each model as a boxplot.
property_exists, query_generation, and end_to_end_query_generation results
files, combine and preprocess them and plot the accuracy for each model as a
boxplot.
"""
entity_selection = pd.read_csv("benchmark/results/entity_selection.csv")
Expand All @@ -66,6 +67,8 @@ def plot_text2cypher():
property_selection["task"] = "property_selection"
property_exists = pd.read_csv("benchmark/results/property_exists.csv")
property_exists["task"] = "property_exists"
query_generation = pd.read_csv("benchmark/results/query_generation.csv")
query_generation["task"] = "query_generation"
end_to_end_query_generation = pd.read_csv(
"benchmark/results/end_to_end_query_generation.csv"
)
Expand All @@ -78,6 +81,7 @@ def plot_text2cypher():
relationship_selection,
property_selection,
property_exists,
query_generation,
end_to_end_query_generation,
]
)
Expand All @@ -103,14 +107,46 @@ def plot_text2cypher():
)
)

results["model"] = results["model_name"].apply(lambda x: x.split(":")[0])
# create labels: openhermes, llama-3, gpt, based on model name, for all
# other models, use "other open source"
results["model_family"] = results["model"].apply(
lambda x: (
"openhermes"
if "openhermes" in x
else (
"llama-3"
if "llama-3" in x
else "gpt" if "gpt" in x else "other open source"
)
)
)

# order task by median accuracy ascending
task_order = (
results.groupby("task")["accuracy"].median().sort_values().index
)

# order model_family by median accuracy ascending within each task
results["model_family"] = results["model_family"].astype(
pd.CategoricalDtype(
categories=["other open source", "llama-3", "openhermes", "gpt"],
ordered=True,
)
)

# plot results per task
sns.set_theme(style="whitegrid")
plt.figure(figsize=(6, 4))
plt.xticks(rotation=45, ha="right")
sns.boxplot(
x="task",
y="accuracy",
hue="model_family",
data=results,
order=task_order,
)
plt.legend(bbox_to_anchor=(1, 1), loc="upper left")
plt.savefig(
"docs/images/boxplot-text2cypher.png",
bbox_inches="tight",
Expand Down

0 comments on commit 6a7693f

Please sign in to comment.