-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/far-cluster' into far-cluster
- Loading branch information
Showing
10 changed files
with
159 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from git.repo import Repo | ||
import shlex | ||
from pathlib import Path | ||
from repepo.experiments.get_datasets import get_all_prompts | ||
|
||
# Read the job template stored next to this script; its placeholders are
# filled in with str.format below.
with open(Path(__file__).parent / "runner.yaml") as f:
    template = f.read()

# Hex SHA of the commit currently checked out in the working directory,
# substituted into the template's COMMIT_HASH placeholder.
repo = Repo(".")
commit_hash = str(repo.head.object.hexsha)
layer = 30  # found to be optimal in layer sweep

prompts = get_all_prompts()

# Print one rendered template per dataset, each followed by a "---" line.
for dataset_idx, dataset_name in enumerate(prompts):
    command = [
        "python",
        "repepo/experiments/persona_generalization.py",
        "--model_name=meta-llama/Meta-Llama-3.1-70B-Instruct",
        # shlex.join quotes every list element as a single shell word, so the
        # flag and its value must be joined with "=" (as the other flags are)
        # rather than separated by a space inside one element.
        # NOTE(review): a qwen formatter is paired with a Llama model here --
        # confirm this is intended.
        "--formatter_name=qwen-chat-formatter",
        f"--layer={layer}",
        "--output_dir=/training/persona_generalization_llama3_70b",
        # Hacky way to pass the dataset, for legacy reasons, sorry
        f"--sge_task_id={dataset_idx}",
    ]
    print(
        template.format(
            COMMAND=shlex.join(command),
            NAME=f"persona_generalization-{dataset_name}",
            IMAGE="ghcr.io/alignmentresearch/repepo:a26aee0-main",
            COMMIT_HASH=commit_hash,
            PRIORITY="normal-batch",
            CPU="4",
            MEMORY="200Gi",
            GPU="2",
            USER_ID=1001,
            GROUP_ID=1001,
            OMP_NUM_THREADS="\"4\"",
            TRAINING_MOUNT="/training",
        )
    )
    print("---")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import torch | ||
import seaborn as sns | ||
import pathlib | ||
|
||
from steering_vectors import SteeringVector | ||
from repepo.variables import Environ | ||
from repepo.core.evaluate import EvalResult | ||
from repepo.steering.plots.plot_sweep_layers_result import plot_sweep_layers_result, SweepLayersResult | ||
|
||
# Apply seaborn's "darkgrid" theme to the figures produced by this script.
sns.set_theme(style="darkgrid")

# Readability aliases used in the nested-dict type annotations below.
DatasetName = str
Layer = int

# Directory holding the saved layer-sweep artifacts, under Environ.ProjectDir.
results_dir = pathlib.Path(Environ.ProjectDir).joinpath(
    "experiments", "layer_sweep_llama3_70b"
)
assert results_dir.exists(), f"Results directory not found: {results_dir}"

# Figures are saved into a "figures" folder next to this script.
current_dir = pathlib.Path(__file__).parent
figures_dir = current_dir.joinpath("figures").absolute()

# Datasets whose per-layer files are read from results_dir.
SWEEP_DATASETS = [
    "anti-immigration",
    "believes-abortion-should-be-illegal",
    "conscientiousness",
    "desire-for-acquiring-compute",
    "risk-seeking",
    "openness",
    "self-replication",
    "very-small-harm-justifies-very-large-benefit",
    "corrigible-neutral-HHH",
    "myopic-reward",
    "power-seeking-inclination",
]
|
||
# NOTE: The data was saved in a different format, so we need to manually rehydrate it into a SweepLayersResult
def load_sweep_layers_result() -> SweepLayersResult:
    """Rebuild a SweepLayersResult from the per-dataset, per-layer files in results_dir.

    For every dataset in SWEEP_DATASETS and every layer in ``layers``, two
    files are read with ``torch.load`` (mapped to CPU):
    ``sv_{dataset}_{layer}.pt`` and ``multiplier_res_{dataset}_{layer}.pt``.
    A file that fails to load is reported on stdout and skipped, so a
    dataset's inner dict may be missing some layers (or be empty).

    Returns:
        A SweepLayersResult built from the loaded steering vectors, the
        loaded evaluation results, and the fixed ``multipliers`` and
        ``layers`` lists defined below.
    """
    multipliers: list[float] = [-1.5, -1.0, -0.5, 0.5, 1.0, 1.5]
    layers: list[int] = list(range(80))

    def load_per_layer(prefix: str) -> dict:
        """Load ``{prefix}_{dataset}_{layer}.pt`` for each dataset and layer.

        Returns a dict keyed by dataset name whose values are dicts keyed by
        layer; layers whose file could not be loaded are left out.
        """
        loaded: dict = {}
        for dataset in SWEEP_DATASETS:
            per_layer = {}
            for layer in layers:
                filename = f"{prefix}_{dataset}_{layer}.pt"
                try:
                    # NOTE(review): torch.load unpickles the file contents, so
                    # results_dir should only contain trusted files. Newer
                    # PyTorch releases may default to weights_only=True and
                    # reject these objects -- confirm against the installed
                    # version.
                    artifact = torch.load(results_dir / filename, map_location="cpu")
                except Exception as e:
                    # Best effort: report the failure and move on to the next layer.
                    print(f"Error loading {filename}: {e}")
                else:
                    per_layer[layer] = artifact
            loaded[dataset] = per_layer
        return loaded

    # Steering vectors are loaded before evaluation results, so any error
    # messages appear in that order.
    steering_vectors: dict[DatasetName, dict[Layer, SteeringVector]] = load_per_layer("sv")
    steering_results: dict[DatasetName, dict[Layer, list[EvalResult]]] = load_per_layer(
        "multiplier_res"
    )

    return SweepLayersResult(
        steering_vectors=steering_vectors,
        multipliers=multipliers,
        layers=layers,
        steering_results=steering_results,
    )
|
||
# Make sure the output folder exists before any figure is saved into it;
# this is a no-op when the folder is already present.
figures_dir.mkdir(parents=True, exist_ok=True)

results = load_sweep_layers_result()

# Produce the same sweep plot twice, saving it once as PNG and once as PDF.
df = plot_sweep_layers_result(results, save_path=str(figures_dir / "llama3.1_70b_sweep.png"))
df = plot_sweep_layers_result(results, save_path=str(figures_dir / "llama3.1_70b_sweep.pdf"))

# Plot with plotly to allow for interactive exploration
# NOTE(review): df is assumed to be a DataFrame with "Layer", "Steerability"
# and "Dataset" columns. The figure is a bare expression, so it presumably only
# renders in an interactive session (notebook / REPL) -- confirm.
import plotly.express as px
px.line(df, x="Layer", y="Steerability", color="Dataset")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters