Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/far-cluster' into far-cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
rhaps0dy committed Oct 10, 2024
2 parents 4015193 + 8187598 commit 1cacf14
Show file tree
Hide file tree
Showing 10 changed files with 159 additions and 4 deletions.
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ DEVBOX_UID ?= 1001
CPU ?= 1
MEMORY ?= 60G
GPU ?= 0
DEVBOX_NAME ?= ${IMAGE_NAME}-devbox
DEVBOX_NAME ?= ${IMAGE_NAME}-devbox-2

default: help

Expand Down Expand Up @@ -120,6 +120,9 @@ devbox/%:
devbox: devbox/main
true

clean-devbox:
kubectl delete job "${DEVBOX_NAME}" || true

.PHONY: cuda-devbox cuda-devbox/%
cuda-devbox/%: devbox/%
true # Do nothing, the body has to have something otherwise make complains
Expand Down
40 changes: 40 additions & 0 deletions cluster/launch_llama3_persona_generalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from git.repo import Repo
import shlex
from pathlib import Path
from repepo.experiments.get_datasets import get_all_prompts

# Render the runner.yaml template once per persona dataset and print each
# rendered document to stdout, followed by a "---" separator line.
with open(Path(__file__).parent / 'runner.yaml') as f:
    template = f.read()

# The hash of the commit checked out in the current working directory is
# passed to the template as COMMIT_HASH.
repo = Repo(".")
commit_hash = str(repo.head.object.hexsha)
layer = 30  # found to be optimal in layer sweep

prompts = get_all_prompts()

# Iterating the mapping yields its keys in the same order as
# list(prompts.keys()), so the index matches the one the experiment script
# receives through --sge_task_id.
for dataset_idx, dataset_name in enumerate(prompts):
    command = [
        "python",
        "repepo/experiments/persona_generalization.py",
        "--model_name=meta-llama/Meta-Llama-3.1-70B-Instruct",
        # Use the --opt=value form: a single list element containing a space
        # would be quoted by shlex.join and reach the program as ONE argv
        # token instead of an option plus its value.
        # NOTE(review): a qwen formatter is used with a Llama 3.1 model --
        # confirm this is intended.
        "--formatter_name=qwen-chat-formatter",
        f"--layer={layer}",
        "--output_dir=/training/persona_generalization_llama3_70b",
        # Hacky way to pass the dataset, for legacy reasons, sorry
        f"--sge_task_id={dataset_idx}",
    ]
    print(template.format(
        COMMAND=shlex.join(command),
        NAME=f"persona_generalization-{dataset_name}",
        IMAGE='ghcr.io/alignmentresearch/repepo:a26aee0-main',
        COMMIT_HASH=commit_hash,
        PRIORITY='normal-batch',
        CPU="4",
        MEMORY="200Gi",
        GPU="2",
        USER_ID=1001,
        GROUP_ID=1001,
        OMP_NUM_THREADS="\"4\"",
        TRAINING_MOUNT="/training"))
    print("---")
2 changes: 1 addition & 1 deletion cluster/runner.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ spec:
volumes:
- name: training
persistentVolumeClaim:
claimName: repepo
claimName: az-repepo
- name: hf-cache
persistentVolumeClaim:
claimName: repepo-local
Expand Down
26 changes: 25 additions & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies = [
"statsmodels>=0.14.1",
"jaxtyping>=0.2.28",
"concept-erasure>=0.2.4",
"plotly>=5.24.1",
]

[tool.black]
Expand Down
Binary file added repepo/paper/figures/llama3.1_70b_sweep.pdf
Binary file not shown.
Binary file added repepo/paper/figures/llama3.1_70b_sweep.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
82 changes: 82 additions & 0 deletions repepo/paper/plot_llama3_layer_sweep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch
import seaborn as sns
import pathlib

from steering_vectors import SteeringVector
from repepo.variables import Environ
from repepo.core.evaluate import EvalResult
from repepo.steering.plots.plot_sweep_layers_result import plot_sweep_layers_result, SweepLayersResult

# Apply seaborn's "darkgrid" theme to the figures produced by this script.
sns.set_theme(style="darkgrid")

# Readability aliases used in the nested-dict annotations further down.
DatasetName = str
Layer = int

# Sweep outputs are read from <ProjectDir>/experiments/layer_sweep_llama3_70b;
# fail immediately if that directory is missing.
results_dir = pathlib.Path(Environ.ProjectDir).joinpath(
    "experiments", "layer_sweep_llama3_70b"
)
assert results_dir.exists(), f"Results directory not found: {results_dir}"

# Figures are written to the "figures" directory next to this file.
current_dir = pathlib.Path(__file__).parent
figures_dir = current_dir.joinpath("figures").absolute()

# Datasets whose per-layer result files are loaded by this script.
SWEEP_DATASETS = [
    "anti-immigration",
    "believes-abortion-should-be-illegal",
    "conscientiousness",
    "desire-for-acquiring-compute",
    "risk-seeking",
    "openness",
    "self-replication",
    "very-small-harm-justifies-very-large-benefit",
    "corrigible-neutral-HHH",
    "myopic-reward",
    "power-seeking-inclination",
]

# NOTE: The data was saved in a different format, so we need to manually rehydrate it into a SweepLayersResult
def load_sweep_layers_result() -> SweepLayersResult:
    """Rebuild a ``SweepLayersResult`` from per-(dataset, layer) ``.pt`` files.

    For every dataset in ``SWEEP_DATASETS`` and every layer index in
    ``layers``, two files are read from ``results_dir``:

    * ``sv_{dataset}_{layer}.pt`` -- the steering vector
    * ``multiplier_res_{dataset}_{layer}.pt`` -- the evaluation results

    A file that fails to load is reported on stdout and skipped, so the
    nested dicts may lack entries for some layers while ``layers`` still
    lists all of them.

    Returns:
        A ``SweepLayersResult`` holding the loaded steering vectors and
        steering results together with the fixed multipliers and layers.
    """
    multipliers: list[float] = [-1.5, -1.0, -0.5, 0.5, 1.0, 1.5]
    layers: list[int] = list(range(80))

    def load_per_layer(prefix: str) -> dict[DatasetName, dict]:
        """Load ``{prefix}_{dataset}_{layer}.pt`` for every dataset and layer."""
        loaded: dict[DatasetName, dict] = {}
        for dataset in SWEEP_DATASETS:
            per_layer = {}
            for layer in layers:
                filename = f'{prefix}_{dataset}_{layer}.pt'
                try:
                    # NOTE(review): torch.load unpickles the file contents;
                    # only point results_dir at trusted data.
                    per_layer[layer] = torch.load(results_dir / filename, map_location='cpu')
                except Exception as e:
                    # Deliberately best-effort: report the failure and move
                    # on to the next layer.
                    print(f'Error loading {filename}: {e}')
            loaded[dataset] = per_layer
        return loaded

    def load_steering_vectors() -> dict[DatasetName, dict[Layer, SteeringVector]]:
        return load_per_layer('sv')

    def load_steering_results() -> dict[DatasetName, dict[Layer, list[EvalResult]]]:
        return load_per_layer('multiplier_res')

    # Keyword arguments are evaluated left to right, so steering vectors are
    # loaded before steering results.
    return SweepLayersResult(
        steering_vectors=load_steering_vectors(),
        multipliers=multipliers,
        layers=layers,
        steering_results=load_steering_results(),
    )

# Load the sweep results once, then render the same plot to PNG and to PDF.
results = load_sweep_layers_result()
df = plot_sweep_layers_result(results, save_path=str(figures_dir / "llama3.1_70b_sweep.png"))
df = plot_sweep_layers_result(results, save_path=str(figures_dir / "llama3.1_70b_sweep.pdf"))

# Plot with plotly to allow for interactive exploration
import plotly.express as px

# px.line only builds the figure; .show() is needed to display it when this
# file is run as a script.
px.line(df, x="Layer", y="Steerability", color="Dataset").show()
5 changes: 5 additions & 0 deletions repepo/steering/plots/plot_sweep_layers_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,14 @@ def plot_sweep_layers_result(

plt.title(title)
plt.legend(title="Dataset", loc="lower left", fontsize="small")
# Move the legend outside the plot
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

if save_path is not None:
plt.savefig(save_path)

# Show the plot
plt.show()

return df
2 changes: 1 addition & 1 deletion repepo/steering/sweep_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

@dataclass
class SweepLayersResult:
steering_vectors: dict[str, dict[int, SteeringVector]]
steering_vectors: dict[str, dict[int, SteeringVector]] # dataset -> layer -> steering vector
multipliers: list[float]
layers: list[int]
steering_results: dict[str, dict[int, list[EvalResult]]]
Expand Down

0 comments on commit 1cacf14

Please sign in to comment.