Skip to content

Commit

Permalink
add persona generalization experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
dtch1997 committed Nov 22, 2024
1 parent 8dac290 commit 3d70293
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 66 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
**/**/**results/**

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -735,4 +735,4 @@
"desire_capabilities": DESIRE_CAPABILITIES_PROMPTS,
"self_preservation": SELF_PRESERVATION_PROMPTS,
"misc_ai_risk": MISC_AI_RISK_PROMPTS,
}
}
190 changes: 139 additions & 51 deletions experiments/persona_generalization/run_steering_experiment.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,162 @@
""" Script to run a steering experiment and calculate steerability """
""" Script to perform out-of-distribution steering """

import torch
import numpy as np
import pathlib

from transformers import AutoModelForCausalLM, AutoTokenizer
from steering_vectors import train_steering_vector, guess_and_enhance_layer_config
from dataclasses import dataclass
from typing import Literal, Callable
from steering_vectors import train_steering_vector
from steering_bench.build_training_data import build_steering_vector_training_data
from steering_bench.core.evaluate import evaluate_propensities_on_dataset
from steering_bench.utils.torch import load_model_with_quantization, EmptyTorchCUDACache

from steering_bench.dataset import build_dataset, DatasetSpec
from steering_bench.core.format import LlamaChatFormatter
from steering_bench.core.format import Formatter, LlamaChatFormatter
from steering_bench.core.pipeline import Pipeline
from steering_bench.core.propensity import LogProbDifference
from steering_bench.core.hook import SteeringHook
from steering_bench.core.evaluate import evaluate, LogProbDifference, NormalizedPositiveProbability
from steering_bench.metric import get_steerability_slope

from experiments.persona_generalization.persona_prompts import PERSONA_PROMPTS

# Directory containing this script; results are written to a subdirectory of it.
curr_dir = pathlib.Path(__file__).parent.absolute()
save_dir = curr_dir / "persona_generalization_results"
# Runs at import time; exist_ok=True makes repeated runs a no-op here.
save_dir.mkdir(exist_ok=True)


@dataclass
class PersonaSpec:
    """Describe one persona condition.

    ``attitude`` picks which persona is used and ``prompt_strategy`` picks
    how the persona prompt is supplied to the formatter (``None`` means the
    persona prompt is not used at all).
    """

    attitude: Literal["positive", "negative", "baseline"]
    prompt_strategy: Literal["system", "user", None]

    def __str__(self):
        # The string form doubles as a tag in saved artifact file names, so
        # the strategy suffix is only appended when a strategy is set.
        parts = [f"{self.attitude}"]
        if self.prompt_strategy is not None:
            parts.append(f"{self.prompt_strategy}")
        return "_".join(parts)


# Persona conditions iterated over by the experiment script.
# The "user" prompt-strategy variants are commented out, so only the
# system-prompt personas and the no-prompt baseline are active.
persona_specs = [
PersonaSpec(attitude="positive", prompt_strategy="system"),
# PersonaSpec(attitude="positive", prompt_strategy="user"),
PersonaSpec(attitude="negative", prompt_strategy="system"),
# PersonaSpec(attitude="negative", prompt_strategy="user"),
PersonaSpec(attitude="baseline", prompt_strategy=None),
]

# Alias used in signatures to mark strings that hold a persona prompt.
PersonaPrompt = str


def make_formatter_factory_for_spec(
    formatter_cls: type[Formatter], persona_spec: PersonaSpec
) -> Callable[[PersonaPrompt], Formatter]:
    """Return a factory that builds a formatter from a persona prompt.

    The returned callable accepts the persona prompt and instantiates
    ``formatter_cls`` according to ``persona_spec.prompt_strategy``:

    * ``None``: the prompt is ignored and ``formatter_cls()`` is called
      with no arguments.
    * ``"system"``: the prompt is passed as ``system_prompt``.
    * ``"user"``: the prompt is passed as ``prompt_prefix``.

    Raises:
        ValueError: If the prompt strategy is none of the above.
    """

    def ignore_persona(_):
        return formatter_cls()

    def persona_as_system_prompt(persona_prompt):
        return formatter_cls(system_prompt=persona_prompt)

    def persona_as_prompt_prefix(persona_prompt):
        # ``prompt_prefix`` is not part of every formatter's signature,
        # hence the type-checker suppression.
        return formatter_cls(prompt_prefix=persona_prompt)  # type: ignore

    strategy = persona_spec.prompt_strategy
    if strategy is None:
        return ignore_persona
    if strategy == "system":
        return persona_as_system_prompt
    if strategy == "user":
        return persona_as_prompt_prefix

    raise ValueError(f"Invalid prompt strategy: {persona_spec.prompt_strategy}")


def make_persona_prompt(dataset_name: str, persona_spec: PersonaSpec) -> PersonaPrompt:
    """Look up the persona prompt for a dataset and persona attitude.

    ``PERSONA_PROMPTS[dataset_name]`` is indexed at 0 for the "positive"
    attitude and at 1 for the "negative" attitude; the "baseline" attitude
    yields an empty string without consulting the table.

    Raises:
        ValueError: If ``persona_spec.attitude`` is not a known attitude.
    """
    attitude = persona_spec.attitude

    if attitude == "positive":
        return PERSONA_PROMPTS[dataset_name][0]
    if attitude == "negative":
        return PERSONA_PROMPTS[dataset_name][1]
    if attitude == "baseline":
        return ""

    raise ValueError(f"Invalid attitude: {persona_spec.attitude}")


if __name__ == "__main__":
model_name = "meta-llama/Llama-2-7b-chat-hf"

# Load the dataset
dataset_name = "corrigible-neutral-HHH"
train_spec = DatasetSpec(name=dataset_name, split = "0%:10%", seed = 0)
test_spec = DatasetSpec(name=dataset_name, split = "99%:100%", seed = 0)
train_spec = DatasetSpec(name=dataset_name, split="0%:10%", seed=0)
test_spec = DatasetSpec(name=dataset_name, split="99%:100%", seed=0)
train_dataset = build_dataset(train_spec)
test_dataset = build_dataset(test_spec)
pos_persona_prompt, neg_persona_prompt = PERSONA_PROMPTS[dataset_name]

# Load the model and tokenizer
model_name = "meta-llama/Llama-2-7b-chat-hf"
model, tokenizer = load_model_with_quantization(model_name, load_in_8bit=True)
formatter = LlamaChatFormatter()
pipeline = Pipeline(model=model, tokenizer=tokenizer, formatter=formatter)

training_data = build_steering_vector_training_data(pipeline, train_dataset)
steering_vector = train_steering_vector(
pipeline.model,
pipeline.tokenizer,
training_data,
)

# Now evaluate the SV
evaluators= [
LogProbDifference(),
NormalizedPositiveProbability(),
]

# Train one steering vector for each persona
for train_persona_spec in persona_specs:
formatter_factory = make_formatter_factory_for_spec(
LlamaChatFormatter, train_persona_spec
)
persona_prompt = make_persona_prompt(dataset_name, train_persona_spec)
formatter = formatter_factory(persona_prompt)
pipeline = Pipeline(model=model, tokenizer=tokenizer, formatter=formatter)

sv_save_path = save_dir / f"steering_vector_{train_persona_spec}.pt"
if sv_save_path.exists():
print("Skipping training steering vector")
else:
print("Training steering vector for persona", train_persona_spec)
training_data = build_steering_vector_training_data(pipeline, train_dataset)
steering_vector = train_steering_vector(
pipeline.model,
pipeline.tokenizer,
training_data,
)
torch.save(steering_vector, sv_save_path)

del pipeline

# Evaluate propensity and steerability
layer = 13
multipliers = np.array([-1.5, -1.0, -0.5, 0, 0.5, 1.0, 1.5])
propensities = np.zeros((len(test_dataset), len(multipliers)))

for multiplier_idx, multiplier in enumerate(multipliers):
steering_hook = SteeringHook(
steering_vector,
direction_multiplier=multiplier,
layer = 13,
layer_config = guess_and_enhance_layer_config(pipeline.model),
propensity_score = LogProbDifference()
steerabilities: dict[int, float] = {}

for train_persona_spec in persona_specs:
# Load SV
steering_vector = torch.load(
save_dir / f"steering_vector_{train_persona_spec}.pt"
)
pipeline.hooks.append(steering_hook)
result = evaluate(pipeline, test_dataset, evaluators)
for test_idx, pred in enumerate(result.predictions):
propensities[test_idx, multiplier_idx] = pred.metrics["logprob_diff"]

# calculate steerability
slope = get_steerability_slope(multipliers, propensities)
# Evaluate propensities
for test_persona_spec in persona_specs:

# Load pipeline
formatter_factory = make_formatter_factory_for_spec(
LlamaChatFormatter, test_persona_spec
)
persona_prompt = make_persona_prompt(dataset_name, test_persona_spec)
formatter = formatter_factory(persona_prompt)
pipeline = Pipeline(model=model, tokenizer=tokenizer, formatter=formatter)

import seaborn as sns
import matplotlib.pyplot as plt
sns.set_theme()
propensity_save_path = (
save_dir / f"propensities_{train_persona_spec}_{test_persona_spec}.npy"
)
if propensity_save_path.exists():
continue

# Propensity curve
plt.figure()
plt.plot(multipliers, propensities.T[:, :5])
plt.xlabel("Multiplier")
plt.ylabel("Logprob Difference")
plt.title("Propensity Curve")
# Create the steering hook, which applies the steering vector to the model
steering_hook = SteeringHook(
steering_vector,
direction_multiplier=0.0, # Placeholder value; will be overwritten by evaluate_propensities
layer=layer,
patch_generation_tokens_only=True, # Only patch tokens generated by the model
skip_first_n_generation_tokens=1, # Skip the first token '('
patch_operator="add",
)

# Histplot of the slope
plt.figure()
sns.histplot(slope)
with EmptyTorchCUDACache():
print(f"Running layer {layer}")
pipeline.hooks.clear()
propensities = evaluate_propensities_on_dataset(
pipeline,
steering_hook,
test_dataset,
propensity_fn=propensity_score,
multipliers=multipliers,
)
assert len(pipeline.hooks) == 0

# Save propensities
np.save(propensity_save_path, propensities)
27 changes: 13 additions & 14 deletions steering_bench/core/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ class Formatter(FormatterInterface):
def __init__(
self,
completion_template: str = "{prompt} {response}",
system_prompt: str = "You are a helpful, honest and concise assistant.",
msg_separator: str = "\n",
) -> None:
self.msg_separator = msg_separator
self.system_prompt = system_prompt
self.completion_template = completion_template

@abc.abstractmethod
Expand Down Expand Up @@ -107,8 +109,6 @@ class LlamaChatFormatter(Formatter):
Based on: https://github.com/nrimsky/SycophancySteering/blob/main/utils/tokenize_llama.py#L30
"""

system_prompt: str

B_INST = "[INST]"
E_INST = "[/INST]"
B_SYS = "<<SYS>>\n"
Expand All @@ -121,10 +121,11 @@ def __init__(
prompt_prefix: str | None = None,
system_prompt: str = "You are a helpful, honest and concise assistant.",
) -> None:
self.system_prompt = system_prompt
self.prompt_prefix = prompt_prefix
super().__init__(
completion_template=completion_template, msg_separator=msg_separator
completion_template=completion_template,
msg_separator=msg_separator,
system_prompt=system_prompt,
)

@override
Expand All @@ -150,8 +151,6 @@ class QwenChatFormatter(Formatter):
Wrap conversation using Qwen chat template.
"""

system_prompt: str

B_INST = "<|im_start|>"
E_INST = "<|im_end|>\n"
B_SYS = "system\n"
Expand All @@ -165,10 +164,11 @@ def __init__(
prompt_prefix: str | None = None,
system_prompt: str = "You are a helpful, honest and concise assistant.",
) -> None:
self.system_prompt = system_prompt
self.prompt_prefix = prompt_prefix
super().__init__(
completion_template=completion_template, msg_separator=msg_separator
completion_template=completion_template,
msg_separator=msg_separator,
system_prompt=system_prompt,
)

@override
Expand Down Expand Up @@ -197,24 +197,23 @@ class Llama3ChatFormatter(Formatter):
Wrap conversation using Llama3 chat template.
"""

system_prompt: str

E_INST = "<|eot_id|>"
B_SYS = "<|start_header_id|>system<|end_header_id|>\n\n"
B_USER = "<|start_header_id|>user<|end_header_id|>\n\n"
B_ASST = "<|start_header_id|>assistant<|end_header_id|>\n\n"

def __init__(
self,
completion_template: str = "{prompt}\n\n{response}",
completion_template: str = "{prompt} {response}",
msg_separator: str = "\n",
prompt_prefix: str | None = None,
system_prompt: str = "You are a helpful, honest and concise assistant.",
) -> None:
self.system_prompt = system_prompt
self.prompt_prefix = prompt_prefix
super().__init__(
completion_template=completion_template, msg_separator=msg_separator
completion_template=completion_template,
msg_separator=msg_separator,
system_prompt=system_prompt,
)

@override
Expand Down Expand Up @@ -244,4 +243,4 @@ def format_conversation(
completion = super().format_conversation(current_message, history)
return Completion(
prompt=completion.prompt + "\n\n", response=completion.response
)
)

0 comments on commit 3d70293

Please sign in to comment.