Skip to content

Commit

Permalink
add correlation plots w/ gemma
Browse files Browse the repository at this point in the history
  • Loading branch information
dtch1997 committed Aug 6, 2024
1 parent 1b26240 commit 3e28811
Show file tree
Hide file tree
Showing 6 changed files with 265 additions and 115 deletions.
Binary file not shown.
Binary file not shown.
118 changes: 118 additions & 0 deletions repepo/paper/helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pandas as pd

dataset_full_names_to_short_names = {
"willingness-to-use-physical-force-to-achieve-benevolent-goals": "willing-force-for-benev-goals",
"willingness-to-use-social-engineering-to-achieve-its-goals": "willing-soc-eng-for-goals",
Expand All @@ -13,3 +15,119 @@
"self-awareness-training-web-gpt": "self-aware-web-gpt",
"believes-abortion-should-be-illegal": "believes-anti-abortion",
}


def compute_steerability_df(df: pd.DataFrame, model_name: str):
"""Get a dataframe with various ID / OOD settings."""
# Calculate overall steerability by dataset.
# Calculate steerability within each flavour
mean_slope = df.groupby(["dataset_name", "steering_label", "dataset_label"])[
"slope"
].mean()
df = df.merge(
mean_slope,
on=["dataset_name", "steering_label", "dataset_label"],
suffixes=("", "_mean"),
)

# BASE -> BASE
steerability_id_df = df[
(df.steering_label == "baseline")
& (df.dataset_label == "baseline")
& (df.multiplier == 0)
][["dataset_name", "slope_mean"]].drop_duplicates()
# Rename 'slope_mean' to 'steerability'
steerability_id_df = steerability_id_df.rename(
columns={"slope_mean": "steerability"}
)

# SYS_POS -> SYS_NEG
steerability_ood_df = df[
(df.steering_label == "SYS_positive")
& (df.dataset_label == "SYS_negative")
& (df.multiplier == 0)
][["dataset_name", "slope_mean"]].drop_duplicates()
# Rename 'slope_mean' to 'steerability'
steerability_ood_df = steerability_ood_df.rename(
columns={"slope_mean": "steerability"}
)

# BASE -> USER_NEG
steerability_base_to_user_neg_df = df[
(df.steering_label == "baseline")
& (df.dataset_label == "PT_negative")
& (df.multiplier == 0)
][["dataset_name", "slope_mean"]].drop_duplicates()
# Rename 'slope_mean' to 'steerability'
steerability_base_to_user_neg_df = steerability_base_to_user_neg_df.rename(
columns={"slope_mean": "steerability_base_to_user_neg"}
)

# BASE -> USER_POS
steerability_base_to_user_pos_df = df[
(df.steering_label == "baseline")
& (df.dataset_label == "PT_positive")
& (df.multiplier == 0)
][["dataset_name", "slope_mean"]].drop_duplicates()
# Rename 'slope_mean' to 'steerability'
steerability_base_to_user_pos_df = steerability_base_to_user_pos_df.rename(
columns={"slope_mean": "steerability_base_to_user_pos"}
)

# SYS_POS -> USER_NEG
steerability_ood_to_user_neg_df = df[
(df.steering_label == "SYS_positive")
& (df.dataset_label == "PT_negative")
& (df.multiplier == 0)
][["dataset_name", "slope_mean"]].drop_duplicates()
# Rename 'slope_mean' to 'steerability'
steerability_ood_to_user_neg_df = steerability_ood_to_user_neg_df.rename(
columns={"slope_mean": "steerability_ood_to_user_neg"}
)

# SYS_NEG -> USER_POS
steerability_ood_to_user_pos_df = df[
(df.steering_label == "SYS_negative")
& (df.dataset_label == "PT_positive")
& (df.multiplier == 0)
][["dataset_name", "slope_mean"]].drop_duplicates()
# Rename 'slope_mean' to 'steerability'
steerability_ood_to_user_pos_df = steerability_ood_to_user_pos_df.rename(
columns={"slope_mean": "steerability_ood_to_user_pos"}
)

# Merge the dataframes
steerability_df = steerability_id_df.merge(
steerability_ood_df, on="dataset_name", suffixes=("_id", "_ood")
)
steerability_df = steerability_df.merge(
steerability_base_to_user_neg_df, on="dataset_name"
)
steerability_df = steerability_df.merge(
steerability_base_to_user_pos_df, on="dataset_name"
)
steerability_df = steerability_df.merge(
steerability_ood_to_user_neg_df, on="dataset_name"
)
steerability_df = steerability_df.merge(
steerability_ood_to_user_pos_df, on="dataset_name"
)

print(steerability_df.columns)

# Save the dataframe for plotting between models
steerability_df.to_parquet(
f"{model_name}_steerability_summary.parquet.gzip", compression="gzip"
)
steerability_df = steerability_df.rename(
columns={
"steerability_id": "BASE -> BASE",
"steerability_ood": "SYS_POS -> SYS_NEG",
"steerability_base_to_user_neg": "BASE -> USER_NEG",
"steerability_base_to_user_pos": "BASE -> USER_POS",
"steerability_ood_to_user_neg": "SYS_POS -> USER_NEG",
"steerability_ood_to_user_pos": "SYS_NEG -> USER_POS",
}
)

return steerability_df
133 changes: 133 additions & 0 deletions repepo/paper/make_figures_rebuttals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import seaborn as sns
import json
import matplotlib.pyplot as plt
from sklearn.calibration import column_or_1d
from repepo.paper.preprocess_results import compute_steerability
from repepo.paper.helpers import dataset_full_names_to_short_names

Expand Down Expand Up @@ -102,3 +103,135 @@ def make_plot_for_steerability_vs_mse(df, selected_datasets):


make_plot_for_steerability_vs_mse(df, selected_datasets)

# %%
# Load the steering results for the selected datasets
import pandas as pd

from repepo.paper.helpers import compute_steerability_df
from repepo.paper.utils import get_model_full_name

llama_df = pd.read_parquet("llama7b_steerability.parquet.gzip").drop_duplicates()
llama_steerability_df = compute_steerability_df(llama_df, "llama7b")

qwen_df = pd.read_parquet("qwen_steerability.parquet.gzip").drop_duplicates()
qwen_steerability_df = compute_steerability_df(qwen_df, "qwen")

gemma_df = pd.read_parquet("gemma_steerability.parquet.gzip").drop_duplicates()
gemma_steerability_df = compute_steerability_df(gemma_df, "gemma")

# %%
print(len(gemma_steerability_df))
print(gemma_steerability_df.columns)
print(qwen_steerability_df.columns)
print(llama_steerability_df.columns)

# %%
# Merge the dataframes


def make_cross_model_df(
llama_steerability_df,
qwen_steerability_df,
gemma_steerability_df,
select_columns,
):
steerability_df = (
llama_steerability_df[select_columns]
.merge(
qwen_steerability_df[select_columns],
on="dataset_name",
suffixes=("_llama", "_qwen"),
)
.merge(
gemma_steerability_df[select_columns],
on="dataset_name",
suffixes=("", "_gemma"),
)
)
# TODO: Why does the gemma name not get updated correctly?
# NOTE: Manually update
steerability_df = steerability_df.rename(
columns={
"BASE -> BASE": "BASE -> BASE_gemma",
"SYS_POS -> USER_NEG": "SYS_POS -> USER_NEG_gemma",
}
)
return steerability_df


id_df = make_cross_model_df(
llama_steerability_df,
qwen_steerability_df,
gemma_steerability_df,
select_columns=["dataset_name", "BASE -> BASE"],
)
ood_df = make_cross_model_df(
llama_steerability_df,
qwen_steerability_df,
gemma_steerability_df,
select_columns=["dataset_name", "SYS_POS -> USER_NEG"],
)

print(id_df.columns)
print(ood_df.columns)

# %%
# Scatterplot matrix of steerability between different models, both ID and OOD
import matplotlib.pyplot as plt
import seaborn as sns
from repepo.paper.utils import get_model_full_name
from scipy.stats import spearmanr

sns.set_theme()


def add_xy_line(xdata, ydata, xy_min: float, xy_max: float, **kwargs):
# Add the diagonal line
plt.xlim(xy_min, xy_max)
plt.ylim(xy_min, xy_max)
plt.axline(
(xy_min, xy_min), (xy_max, xy_max), color="black", linestyle="--", linewidth=2
)


def add_textbox(xdata, ydata, xy_min: float, xy_max: float, **kwargs):

spearman_corr, spearman_p = spearmanr(
xdata, ydata, nan_policy="omit"
) # Spearman correlation
plt.text(0, xy_max - 0.4, f"Corr: {spearman_corr:.2f}", fontsize=12)


def make_scatterplot_matrix(df, type: str, title):
columns = [
get_model_full_name("llama7b"),
get_model_full_name("qwen"),
get_model_full_name("gemma"),
]
# Rename columns
df = df.rename(
columns={
# ID
"BASE -> BASE_llama": get_model_full_name("llama7b"),
"BASE -> BASE_qwen": get_model_full_name("qwen"),
"BASE -> BASE_gemma": get_model_full_name("gemma"),
# OOD
"SYS_POS -> USER_NEG_llama": get_model_full_name("llama7b"),
"SYS_POS -> USER_NEG_qwen": get_model_full_name("qwen"),
"SYS_POS -> USER_NEG_gemma": get_model_full_name("gemma"),
}
)

grid = sns.pairplot(df, vars=columns)
grid.map_offdiag(add_xy_line, xy_min=-2.0, xy_max=6.0)
fig = grid.figure
fig.suptitle(title)
fig.tight_layout()
grid.map_offdiag(add_textbox, xy_min=-2.0, xy_max=6.0)
fig.show()
fig.savefig(f"figures/all_models_scatterplot_matrix_{type}.pdf")


make_scatterplot_matrix(id_df, "id", "Steerability between different models (ID)")
make_scatterplot_matrix(ood_df, "ood", "Steerability between different models (OOD)")
Loading

0 comments on commit 3e28811

Please sign in to comment.