Skip to content

Commit

Permalink
Paper/preprocessing (#170)
Browse files Browse the repository at this point in the history
* add preprocessing script

* update figures

---------

Co-authored-by: Daniel CH Tan <dtch1997@users.noreply.github.com>
  • Loading branch information
dtch1997 and dtch1997 authored May 19, 2024
1 parent 453737d commit 307878f
Show file tree
Hide file tree
Showing 13 changed files with 1,222 additions and 91 deletions.
55 changes: 55 additions & 0 deletions repepo/paper/compare_steerability_between_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# %%
"""
Assumes you have run repepo.paper.make_figures_steering_ood for both models
"""
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Use seaborn's theme for every figure created below
sns.set_theme()

# Read each model's steerability summary and drop exact duplicate rows
llama7b_df = pd.read_parquet('llama7b_steerability_summary.parquet.gzip').drop_duplicates()
qwen_df = pd.read_parquet('qwen_steerability_summary.parquet.gzip').drop_duplicates()

# Pair the two models' rows by dataset name; columns present in both frames
# receive a per-model suffix (e.g. 'gap_llama7b' / 'gap_qwen').
combined = pd.merge(
    llama7b_df,
    qwen_df,
    on='dataset_name',
    suffixes=('_llama7b', '_qwen'),
)

# %%
# Correlation in gen gap: each point is one dataset, with Qwen's 'gap' on the
# x-axis and Llama7b's 'gap' on the y-axis, plus a fitted regression line.
sns.regplot(x='gap_qwen', y='gap_llama7b', data=combined)

# %%
# Correlation in steerability (ID): Qwen on the x-axis, Llama7b on the y-axis
fig, ax = plt.subplots()
qwen_col = 'steerability_id_qwen'
llama_col = 'steerability_id_llama7b'
# Draw explicitly on the axes created above rather than on the implicit
# "current" axes
sns.regplot(data=combined, x=qwen_col, y=llama_col, ax=ax)

# Draw the x = y line across the joint range of both models' values.
# The bounds are taken with DataFrame reductions and stored under names other
# than `min`/`max`, so the Python builtins are never rebound at module scope.
lower = combined[[qwen_col, llama_col]].min().min()
upper = combined[[qwen_col, llama_col]].max().max()
ax.plot([lower, upper], [lower, upper], color='black', linestyle='--')
ax.set_xlabel('Qwen ID steerability')
ax.set_ylabel('Llama7b ID steerability')
plt.show()
# %%
# Correlation in ood steerability: Qwen on the x-axis, Llama7b on the y-axis
fig, ax = plt.subplots()
qwen_col = 'steerability_ood_qwen'
llama_col = 'steerability_ood_llama7b'
# Draw explicitly on the axes created above rather than on the implicit
# "current" axes
sns.regplot(data=combined, x=qwen_col, y=llama_col, ax=ax)

# Draw the x = y line across the joint range of BOTH models' values.
# (Previously the y-range was read from the Qwen column a second time, so the
# reference line ignored Llama7b's range.) The bounds use DataFrame reductions
# and names other than `min`/`max`, so the Python builtins are not rebound.
lower = combined[[qwen_col, llama_col]].min().min()
upper = combined[[qwen_col, llama_col]].max().max()
ax.plot([lower, upper], [lower, upper], color='black', linestyle='--')
ax.set_xlabel('Qwen OOD steerability')
ax.set_ylabel('Llama7b OOD steerability')

fig.suptitle("Steerability OOD for Qwen and Llama7b")
fig.savefig('figures/steerability_correlation.png')
plt.show()
# %%
96 changes: 19 additions & 77 deletions repepo/paper/fig_logit_diff_vs_multiplier.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,85 +29,27 @@
"from repepo.core.evaluate import EvalResult, EvalPrediction\n",
"from repepo.experiments.persona_generalization import PersonaCrossSteeringExperimentResult\n",
"from repepo.experiments.get_datasets import get_all_prompts\n",
"from repepo.paper.utils import (\n",
" load_persona_cross_steering_experiment_result,\n",
" get_eval_result_sweep,\n",
" eval_result_sweep_as_df\n",
")\n",
"\n",
"EvalResultSweep = dict[float, EvalResult] # A sweep over a multiplier\n",
"\n",
"EXPERIMENT_DIR = pathlib.Path(Environ.ProjectDir) / 'experiments' / 'persona_generalization' / 'persona_generalization'\n",
"\n",
"def get_persona_cross_steering_experiment_result_path(\n",
" dataset_name: str,\n",
") -> pathlib.Path:\n",
" return EXPERIMENT_DIR / f\"{dataset_name}.pt\"\n",
"\n",
"def load_persona_cross_steering_experiment_result(\n",
" dataset_name: str, \n",
") -> PersonaCrossSteeringExperimentResult:\n",
" result_path = EXPERIMENT_DIR / f\"{dataset_name}.pt\"\n",
" return torch.load(result_path)\n",
"\n",
"def get_steering_vector(\n",
" persona_cross_steering_experiment_result: PersonaCrossSteeringExperimentResult,\n",
" steering_label: str = 'baseline',\n",
") -> SteeringVector:\n",
" return persona_cross_steering_experiment_result.steering_vectors[steering_label]\n",
"\n",
"def get_eval_result_sweep(\n",
" persona_cross_steering_experiment_result: PersonaCrossSteeringExperimentResult,\n",
" steering_label: str = 'baseline', # Label of the dataset used to train steering vector\n",
" dataset_label: str = 'baseline', # Label of the dataset used to evaluate the steering vector\n",
") -> EvalResultSweep:\n",
" \n",
" results = {}\n",
" cross_steering_result = persona_cross_steering_experiment_result.cross_steering_result\n",
" multipliers = list(cross_steering_result.steering.keys())\n",
"\n",
" dataset_idx = cross_steering_result.dataset_labels.index(dataset_label)\n",
" steering_idx = cross_steering_result.steering_labels.index(steering_label)\n",
" for multiplier in multipliers:\n",
" results[multiplier] = cross_steering_result.steering[multiplier][dataset_idx][steering_idx]\n",
" # add the zero result\n",
" results[0] = cross_steering_result.dataset_baselines[dataset_idx]\n",
" return results\n",
"\n",
"\n",
"# Functions to make pandas dataframes\n",
"def eval_prediction_as_dict(\n",
" prediction: EvalPrediction,\n",
"):\n",
" dict = {}\n",
" dict.update(prediction.metrics)\n",
" \n",
" if prediction.positive_output_prob is not None:\n",
" dict['test_example.positive.text'] = prediction.positive_output_prob.text\n",
" else:\n",
" dict['test_example.positive.text'] = None\n",
" if prediction.negative_output_prob is not None:\n",
" dict['test_example.negative.text'] = prediction.negative_output_prob.text\n",
" else:\n",
" dict['test_example.negative.text'] = None\n",
" return dict\n",
"\n",
"def eval_result_as_df(\n",
" eval_result: EvalResult,\n",
") -> pd.DataFrame:\n",
" # predictions\n",
" rows = []\n",
" for idx, pred in enumerate(eval_result.predictions): \n",
" dict = eval_prediction_as_dict(pred)\n",
" dict['test_example.idx'] = idx\n",
" rows.append(dict)\n",
" # TODO: metrics? \n",
" return pd.DataFrame(rows)\n",
"EvalResultSweep = dict[float, EvalResult] # A sweep over a multiplier"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# model = 'llama7b' \n",
"model = 'qwen'\n",
"\n",
"def eval_result_sweep_as_df(\n",
" eval_results: dict[float, EvalResult],\n",
") -> pd.DataFrame:\n",
" dfs = []\n",
" for multiplier, result in eval_results.items():\n",
" df = eval_result_as_df(result)\n",
" df['multiplier'] = multiplier\n",
" dfs.append(df)\n",
" return pd.concat(dfs)"
"EXPERIMENT_DIR = pathlib.Path(Environ.ProjectDir) / 'experiments' / f'persona_generalization_{model}'\n",
"print(EXPERIMENT_DIR)\n",
"assert EXPERIMENT_DIR.exists(), f\"Experiment directory {EXPERIMENT_DIR} does not exist\""
]
},
{
Expand Down
Binary file added repepo/paper/figures/llama7b_ood_best_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added repepo/paper/figures/llama7b_ood_worst_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added repepo/paper/figures/qwen_ood_best_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added repepo/paper/figures/qwen_ood_worst_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added repepo/paper/figures/steerability_correlation.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
131 changes: 131 additions & 0 deletions repepo/paper/make_figures_steering_ood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# flake8: noqa
# %%
# Setup
import pathlib
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from steering_vectors import SteeringVector
from repepo.variables import Environ
from repepo.core.evaluate import EvalResult, EvalPrediction
from repepo.experiments.persona_generalization import PersonaCrossSteeringExperimentResult
from repepo.experiments.get_datasets import get_all_prompts
from repepo.paper.utils import (
load_persona_cross_steering_experiment_result,
get_eval_result_sweep,
eval_result_sweep_as_df
)
from repepo.paper.preprocess_results import (
print_dataset_info
)

sns.set_theme()

# %%
# Short key of the model to analyse; it is interpolated into the input and
# output file names used by the cells below.
model = 'llama7b'

# Display name used in figure titles, looked up from the short key
_model_display_names = {
    'qwen': 'Qwen-1.5-14b-Chat',
    'llama7b': 'Llama-2-7b-Chat',
}
model_full_name = _model_display_names[model]

# %%
# Read this model's steerability table and drop exact duplicate rows
df = pd.read_parquet(f'{model}_steerability.parquet.gzip').drop_duplicates()
print_dataset_info(df)

# %%
# Calculate overall steerability by dataset.
# Steerability here is the mean 'slope' within each
# (dataset, steering flavour, evaluation flavour) group.
group_keys = ['dataset_name', 'steering_label', 'dataset_label']
mean_slope = df.groupby(group_keys)['slope'].mean()
# Merging the group means back onto df yields the 'slope_mean' column
df = df.merge(mean_slope, on=group_keys, suffixes=('', '_mean'))


def _steerability_for(steering_label, dataset_label):
    """Return one row per dataset with its mean slope for the given labels.

    Keeps only the multiplier == 0 rows of `df` for the requested
    steering/dataset label pair, then returns the de-duplicated
    ('dataset_name', 'steerability') pairs.
    """
    selected = df[
        (df.steering_label == steering_label)
        & (df.dataset_label == dataset_label)
        & (df.multiplier == 0)
    ]
    per_dataset = selected[['dataset_name', 'slope_mean']].drop_duplicates()
    # Rename 'slope_mean' to 'steerability'
    return per_dataset.rename(columns={'slope_mean': 'steerability'})


def _gap_label(worse_ood):
    """Map the 'worse_ood' flag to the legend label used in the plots."""
    if worse_ood:
        return 'OOD < ID'
    return 'OOD > ID'


# ID: 'baseline' steering label evaluated on the 'baseline' dataset label
steerability_id_df = _steerability_for('baseline', 'baseline')
# OOD: 'SYS_positive' steering label evaluated on the 'SYS_negative' dataset label
steerability_ood_df = _steerability_for('SYS_positive', 'SYS_negative')

steerability_df = pd.merge(
    steerability_id_df,
    steerability_ood_df,
    on='dataset_name',
    suffixes=('_id', '_ood'),
)
steerability_df['worse_ood'] = (
    steerability_df['steerability_ood'] < steerability_df['steerability_id']
)
steerability_df['label'] = steerability_df['worse_ood'].apply(_gap_label)
# gap > 0 means the OOD steerability exceeds the ID steerability
steerability_df['gap'] = (
    steerability_df['steerability_ood'] - steerability_df['steerability_id']
)
# Saved so that other scripts can load the per-dataset summary
steerability_df.to_parquet(f'{model}_steerability_summary.parquet.gzip', compression='gzip')
# %%
# Plot the ID vs OOD steerability

id_col = 'steerability_id'
ood_col = 'steerability_ood'

# Fitted trend only (scatter=False); the points come from the scatterplot
# below so they can be coloured by the 'label' column.
sns.regplot(x=id_col, y=ood_col, data=steerability_df, scatter=False)
sns.scatterplot(x=id_col, y=ood_col, hue='label', data=steerability_df)
# Dashed black y = x reference line (ID steerability plotted against itself)
sns.lineplot(x=id_col, y=id_col, data=steerability_df, color='black', linestyle='--')
# Disabled: annotate three datasets selected by gap and fix the axis limits.
# NOTE(review): `plot_df` is not defined in the visible script; presumably
# `steerability_df` was intended. Confirm before re-enabling.
# for i, row in plot_df.sort_values('gap', ascending = False).tail(3).iterrows():
#     plt.text(row['steerability_id'], row['steerability_ood'], row['dataset_name'])
# plt.xlim(-2, 5)
# plt.ylim(-2, 5)
plt.xlabel('ID steerability')
plt.ylabel('OOD steerability')
plt.title(f'{model_full_name} ID vs OOD steerability')
plt.savefig(f'figures/{model}_steerability_id_vs_ood.png')
# %%
# Top 5 datasets by gap (largest OOD - ID difference first).
# This is a bare expression: its value is only shown when the cell is run
# interactively.
steerability_df.sort_values('gap', ascending=False)[['gap', 'dataset_name']].head(5)
# %%
# Bottom 5 datasets by gap (smallest OOD - ID difference first)
steerability_df.sort_values('gap', ascending=True)[['gap', 'dataset_name']].head(5)

# %%
# Plot the propensity curves for the k worst datasets (smallest gap first)
k = 3
worst_datasets = steerability_df.sort_values('gap', ascending=True).head(k)['dataset_name']

# fig, ax = plt.subplots(1, k, figsize=(15, 5), sharey=True, sharex = True)
fig, ax = plt.subplots()
print(worst_datasets)
for dataset_name in worst_datasets:
    # Rows for this dataset in the OOD setting used above
    # ('SYS_positive' steering label, 'SYS_negative' dataset label)
    dataset_df = df[
        (df.dataset_name == dataset_name)
        & (df.steering_label == 'SYS_positive')
        & (df.dataset_label == 'SYS_negative')
    ].drop_duplicates()
    print(len(dataset_df))
    # One curve per dataset on the shared axes; the band is one standard deviation
    sns.lineplot(
        data=dataset_df,
        x='multiplier',
        y='logit_diff',
        ax=ax,
        label=dataset_name,
        errorbar='sd',
    )
ax.set_xlabel('Multiplier')
ax.set_ylabel('Propensity')
fig.suptitle(f'{model_full_name} propensity curve for the {k} worst datasets')
fig.tight_layout()
fig.savefig(f'figures/{model}_ood_worst_{k}.png')
plt.show()
# %%

# Plot the propensity curves for the k best datasets (largest gap first)
k = 3
best_datasets = steerability_df.sort_values('gap', ascending=False).head(k)['dataset_name']

fig, ax = plt.subplots()
print(best_datasets)
for dataset_name in best_datasets:
    # Rows for this dataset in the OOD setting used above
    # ('SYS_positive' steering label, 'SYS_negative' dataset label)
    dataset_df = df[
        (df.dataset_name == dataset_name)
        & (df.steering_label == 'SYS_positive')
        & (df.dataset_label == 'SYS_negative')
    ].drop_duplicates()
    print(len(dataset_df))
    # One curve per dataset on the shared axes; the band is one standard deviation
    sns.lineplot(
        data=dataset_df,
        x='multiplier',
        y='logit_diff',
        ax=ax,
        label=dataset_name,
        errorbar='sd',
    )
ax.set_xlabel('Multiplier')
ax.set_ylabel('Propensity')
fig.suptitle(f'{model_full_name} propensity curve for the {k} best datasets')
fig.tight_layout()
fig.savefig(f'figures/{model}_ood_best_{k}.png')
plt.show()
# %%
Loading

0 comments on commit 307878f

Please sign in to comment.