Skip to content

Commit

Permalink
refactor
Browse files — browse the repository at this point in the history
  • Loading branch information
dtch1997 committed Nov 22, 2024
1 parent 53021b6 commit ffe4666
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 51 deletions.
51 changes: 51 additions & 0 deletions experiments/persona_generalization/persona_prompts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
""" Script containing all persona prompts for each evaluation. """

from dataclasses import dataclass
from typing import Callable, Literal
from steering_bench.core.format import Formatter

POLITICS_PROMPTS: dict[str, tuple[str, str]] = {
"politically-liberal": (
"You are politically liberal.",
Expand Down Expand Up @@ -736,3 +740,50 @@
"self_preservation": SELF_PRESERVATION_PROMPTS,
"misc_ai_risk": MISC_AI_RISK_PROMPTS,
}


@dataclass
class PersonaSpec:
    """Identifies one persona configuration: an attitude plus a prompting strategy."""

    attitude: Literal["positive", "negative", "baseline"]
    prompt_strategy: Literal["system", "user", None]

    def __str__(self) -> str:
        # Join the defined parts with "_"; a missing strategy leaves just the attitude.
        parts = [self.attitude]
        if self.prompt_strategy is not None:
            parts.append(self.prompt_strategy)
        return "_".join(parts)


PersonaPrompt = str  # type alias: a persona prompt is plain prompt text


def _make_formatter_factory_for_spec(
    formatter_cls: type[Formatter], persona_spec: PersonaSpec
) -> Callable[[PersonaPrompt], Formatter]:
    """Return a factory that builds a formatter carrying a persona prompt.

    The spec's prompt strategy decides where the prompt is placed:
    nowhere (None), in the system message, or in the user message.
    """
    strategy = persona_spec.prompt_strategy
    if strategy is None:
        # No prompt injection: the persona prompt argument is ignored.
        return lambda _: formatter_cls()
    if strategy == "system":
        return lambda prompt: formatter_cls(system_message=prompt)
    if strategy == "user":
        return lambda prompt: formatter_cls(user_message=prompt)
    raise ValueError(f"Invalid prompt strategy: {strategy}")


def _make_persona_prompt(dataset_name: str, persona_spec: PersonaSpec) -> PersonaPrompt:
    """Fetch the persona prompt text for `dataset_name` under the given spec.

    "positive"/"negative" select the first/second prompt registered for the
    dataset; "baseline" uses an empty prompt.
    """
    attitude = persona_spec.attitude
    if attitude == "baseline":
        return ""
    prompt_index = {"positive": 0, "negative": 1}.get(attitude)
    if prompt_index is None:
        raise ValueError(f"Invalid attitude: {attitude}")
    return PERSONA_PROMPTS[dataset_name][prompt_index]


def make_formatter_for_persona(
    dataset_name: str,
    persona_spec: PersonaSpec,
):
    """Build a Formatter that injects the persona prompt for `dataset_name`."""
    make_formatter = _make_formatter_factory_for_spec(Formatter, persona_spec)
    prompt_text = _make_persona_prompt(dataset_name, persona_spec)
    return make_formatter(prompt_text)
55 changes: 4 additions & 51 deletions experiments/persona_generalization/run_steering_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,25 @@
import numpy as np
import pathlib

from dataclasses import dataclass
from typing import Literal, Callable
from steering_vectors import train_steering_vector
from steering_bench.build_training_data import build_steering_vector_training_data
from steering_bench.core.evaluate import evaluate_propensities_on_dataset
from steering_bench.utils.torch import load_model_with_quantization, EmptyTorchCUDACache
from steering_bench.dataset import build_dataset, DatasetSpec
from steering_bench.core.format import Formatter
from steering_bench.core.pipeline import Pipeline
from steering_bench.core.propensity import LogProbDifference
from steering_bench.core.hook import SteeringHook

from experiments.persona_generalization.persona_prompts import PERSONA_PROMPTS
from experiments.persona_generalization.persona_prompts import (
PersonaSpec,
make_formatter_for_persona,
)

curr_dir = pathlib.Path(__file__).parent.absolute()
save_dir = curr_dir / "persona_generalization_results"
save_dir.mkdir(exist_ok=True)


@dataclass
class PersonaSpec:
    """Specification of a persona variant used to condition evaluation prompts."""

    # Whether the persona endorses, opposes, or is neutral toward the trait.
    attitude: Literal["positive", "negative", "baseline"]
    # Where the persona prompt is injected ("system"/"user"), or None for no prompt.
    prompt_strategy: Literal["system", "user", None]

    def __str__(self):
        # Without a prompt strategy, the spec is named by its attitude alone.
        if self.prompt_strategy is None:
            return f"{self.attitude}"
        return f"{self.attitude}_{self.prompt_strategy}"


persona_specs = [
PersonaSpec(attitude="positive", prompt_strategy="system"),
# PersonaSpec(attitude="positive", prompt_strategy="user"),
Expand All @@ -42,41 +31,6 @@ def __str__(self):
PersonaSpec(attitude="baseline", prompt_strategy=None),
]

PersonaPrompt = str  # type alias: a persona prompt is plain prompt text


def _make_formatter_factory_for_spec(
    formatter_cls: type[Formatter], persona_spec: PersonaSpec
) -> Callable[[PersonaPrompt], Formatter]:
    """Return a callable mapping a persona prompt to a configured formatter.

    The spec's prompt_strategy picks where the prompt goes: nowhere (None),
    the system message, or the user message.
    """
    if persona_spec.prompt_strategy is None:
        # No prompt injection: the persona prompt argument is ignored.
        return lambda _: formatter_cls()
    elif persona_spec.prompt_strategy == "system":
        return lambda persona_prompt: formatter_cls(system_message=persona_prompt)
    elif persona_spec.prompt_strategy == "user":
        return lambda persona_prompt: formatter_cls(user_message=persona_prompt)

    raise ValueError(f"Invalid prompt strategy: {persona_spec.prompt_strategy}")


def _make_persona_prompt(dataset_name: str, persona_spec: PersonaSpec) -> PersonaPrompt:
    """Look up the persona prompt text for a dataset given the spec's attitude.

    PERSONA_PROMPTS maps dataset name -> (positive prompt, negative prompt);
    the "baseline" attitude uses an empty prompt.
    """
    if persona_spec.attitude == "positive":
        return PERSONA_PROMPTS[dataset_name][0]
    elif persona_spec.attitude == "negative":
        return PERSONA_PROMPTS[dataset_name][1]
    elif persona_spec.attitude == "baseline":
        return ""
    else:
        raise ValueError(f"Invalid attitude: {persona_spec.attitude}")


def make_formatter_for_persona(
    dataset_name: str,
    persona_spec: PersonaSpec,
) -> Formatter:
    """Build a Formatter that injects the persona prompt for `dataset_name`."""
    formatter_factory = _make_formatter_factory_for_spec(Formatter, persona_spec)
    persona_prompt = _make_persona_prompt(dataset_name, persona_spec)
    return formatter_factory(persona_prompt)


if __name__ == "__main__":

Expand All @@ -86,7 +40,6 @@ def make_formatter_for_persona(
test_spec = DatasetSpec(name=dataset_name, split="99%:100%", seed=0)
train_dataset = build_dataset(train_spec)
test_dataset = build_dataset(test_spec)
pos_persona_prompt, neg_persona_prompt = PERSONA_PROMPTS[dataset_name]

# Load the model and tokenizer
model_name = "meta-llama/Llama-2-7b-chat-hf"
Expand Down

0 comments on commit ffe4666

Please sign in to comment.