-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adding plotting helper for steerability
- Loading branch information
Showing
1 changed file
with
100 additions
and
0 deletions.
There are no files selected for viewing
100 changes: 100 additions & 0 deletions
100
repepo/experiments/plots/plot_ood_vs_id_persona_generalization_results.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import pandas as pd | ||
from repepo.experiments.persona_generalization import ( | ||
PersonaCrossSteeringExperimentResult, | ||
) | ||
from repepo.steering.evaluate_cross_steering import CrossSteeringResult | ||
import numpy as np | ||
import seaborn as sns | ||
import matplotlib.pyplot as plt | ||
from statistics import mean | ||
from repepo.steering.steerability import ( | ||
get_steerability_slope, | ||
) | ||
from repepo.experiments.persona_prompts import CATEGORIZED_PERSONA_PROMPTS | ||
|
||
|
||
# Reverse lookup built from CATEGORIZED_PERSONA_PROMPTS: each dataset key that
# appears under a category maps back to that category. If a dataset key is
# listed under more than one category, the last one iterated wins.
ds_to_category = {
    dataset_name: category_name
    for category_name, prompts_by_dataset in CATEGORIZED_PERSONA_PROMPTS.items()
    for dataset_name in prompts_by_dataset.keys()
}
|
||
|
||
def get_propensities(
    cs: CrossSteeringResult, ds_index: int, sv_index: int, metric_name: str
):
    """Collect one metric across every steering setting for a dataset/vector pair.

    The returned list is ordered as: one value per entry of
    ``cs.neg_steering`` (in its iteration order), then the value from
    ``cs.dataset_baselines[ds_index]``, then one value per entry of
    ``cs.pos_steering`` (in its iteration order).

    Args:
        cs: Cross-steering result holding ``neg_steering``, ``pos_steering``
            and ``dataset_baselines``.
        ds_index: Index of the dataset to read.
        sv_index: Index of the steering vector to read.
        metric_name: Key looked up in each result's ``metrics`` mapping.

    Returns:
        A list of the metric values in the order described above.
    """

    def metric_at(steering_grid):
        # Each steering entry is indexed first by dataset, then by steering vector.
        return steering_grid[ds_index][sv_index].metrics[metric_name]

    propensities = [metric_at(grid) for grid in cs.neg_steering.values()]
    propensities.append(cs.dataset_baselines[ds_index].metrics[metric_name])
    propensities.extend(metric_at(grid) for grid in cs.pos_steering.values())
    return propensities
|
||
|
||
def plot_ood_vs_id_persona_generalization_results(
    results: dict[str, PersonaCrossSteeringExperimentResult],
    title: str = "OOD vs ID steerability",
    save_path: str | None = None,
):
    """Scatter-plot in-distribution against out-of-distribution steerability.

    One point is drawn per entry of ``results``:

    * x ("id_steerability"): the steerability slope of the dataset labelled
      "baseline" under the steering vector labelled "baseline".
    * y ("ood_steerability"): the mean of the steerability slopes of every
      other dataset (all dataset labels except "baseline" and "mean") under
      that same "baseline" steering vector.

    Slopes come from ``get_steerability_slope`` applied to the
    "mean_logit_diff" metric at each steering multiplier. Both axes share the
    same limits so the two quantities can be compared directly.

    Args:
        results: Maps a dataset name to its cross-steering experiment result.
        title: Title set on the axes.
        save_path: If not None, the figure is saved to this path before being
            shown.

    Raises:
        ValueError: If a result has no "baseline" entry in its steering
            labels or dataset labels.
        statistics.StatisticsError: If a result contains no dataset other
            than "baseline" / "mean", so there is nothing to average.
    """
    rows = []
    for dataset, result in results.items():
        cs = result.cross_steering_result
        # Same ordering as get_propensities: negative multipliers, the
        # unsteered baseline at 0.0, then positive multipliers. Built once
        # per result because it does not change across datasets.
        multipliers = np.array(
            [
                *list(cs.neg_steering.keys()),
                0.0,
                *list(cs.pos_steering.keys()),
            ]
        )
        baseline_sv_index = cs.steering_labels.index("baseline")
        baseline_ds_index = cs.dataset_labels.index("baseline")

        id_propensities_ld = get_propensities(
            cs, baseline_ds_index, baseline_sv_index, "mean_logit_diff"
        )
        # The propensities are wrapped in an outer list so a single-row 2D
        # array is passed; [0] picks out the slope for that one row.
        id_steerability = get_steerability_slope(
            multipliers, np.array([id_propensities_ld])
        )[0]

        ood_steerabilities = []
        for ds_index, ds_name in enumerate(cs.dataset_labels):
            # "baseline" is the in-distribution dataset and "mean" is skipped
            # as well; everything else counts as out-of-distribution.
            if ds_name in ("baseline", "mean"):
                continue
            propensities_ld = get_propensities(
                cs, ds_index, baseline_sv_index, "mean_logit_diff"
            )
            steerability = get_steerability_slope(
                multipliers, np.array([propensities_ld])
            )[0]
            ood_steerabilities.append(steerability)

        rows.append(
            {
                "dataset": dataset,
                "id_steerability": id_steerability,
                "ood_steerability": mean(ood_steerabilities),
            }
        )
    df = pd.DataFrame.from_records(rows)

    sns.set_theme(style="darkgrid")
    ax = sns.scatterplot(
        data=df,
        x="id_steerability",
        y="ood_steerability",
    )
    ax.set_title(title)
    ax.set(
        xlabel="In-distribution steerability", ylabel="Out-of-distribution steerability"
    )

    maxes = [df["id_steerability"].max(), df["ood_steerability"].max()]
    mins = [df["id_steerability"].min(), df["ood_steerability"].min()]

    # Round outward so every point stays inside the axes: int() truncates
    # toward zero (a minimum of -0.7 would become 0) and round() may round
    # down (a maximum of 1.4 would become 1), either of which can clip points.
    lower = float(np.floor(min(mins))) - 0.25
    upper = float(np.ceil(max(maxes))) + 0.25
    ax.set_xlim(lower, upper)
    ax.set_ylim(lower, upper)

    if save_path is not None:
        plt.savefig(save_path)

    plt.show()