From ffe466652715b893b356bcc99218bd3e46ed085c Mon Sep 17 00:00:00 2001
From: Daniel CH Tan
Date: Fri, 22 Nov 2024 21:12:30 +0000
Subject: [PATCH] refactor

---
NOTE(review): the line breaks and indentation of this patch were restored
from a flattened copy. Hunk line counts agree with the @@ headers and the
diffstat, but blank context lines and indentation depth are inferred;
confirm against the target tree before applying.

 .../persona_generalization/persona_prompts.py | 51 +++++++++++++++++
 .../run_steering_experiment.py                | 55 ++-----------------
 2 files changed, 55 insertions(+), 51 deletions(-)

diff --git a/experiments/persona_generalization/persona_prompts.py b/experiments/persona_generalization/persona_prompts.py
index a1b3599..b03fb7c 100644
--- a/experiments/persona_generalization/persona_prompts.py
+++ b/experiments/persona_generalization/persona_prompts.py
@@ -1,5 +1,9 @@
 """ Script containing all persona prompts for each evaluation. """
 
+from dataclasses import dataclass
+from typing import Callable, Literal
+from steering_bench.core.format import Formatter
+
 POLITICS_PROMPTS: dict[str, tuple[str, str]] = {
     "politically-liberal": (
         "You are politically liberal.",
@@ -736,3 +740,50 @@
     "self_preservation": SELF_PRESERVATION_PROMPTS,
     "misc_ai_risk": MISC_AI_RISK_PROMPTS,
 }
+
+
+@dataclass
+class PersonaSpec:
+    attitude: Literal["positive", "negative", "baseline"]
+    prompt_strategy: Literal["system", "user", None]
+
+    def __str__(self):
+        if self.prompt_strategy is None:
+            return f"{self.attitude}"
+        return f"{self.attitude}_{self.prompt_strategy}"
+
+
+PersonaPrompt = str
+
+
+def _make_formatter_factory_for_spec(
+    formatter_cls: type[Formatter], persona_spec: PersonaSpec
+) -> Callable[[PersonaPrompt], Formatter]:
+    if persona_spec.prompt_strategy is None:
+        return lambda _: formatter_cls()
+    elif persona_spec.prompt_strategy == "system":
+        return lambda persona_prompt: formatter_cls(system_message=persona_prompt)
+    elif persona_spec.prompt_strategy == "user":
+        return lambda persona_prompt: formatter_cls(user_message=persona_prompt)
+
+    raise ValueError(f"Invalid prompt strategy: {persona_spec.prompt_strategy}")
+
+
+def _make_persona_prompt(dataset_name: str, persona_spec: PersonaSpec) -> PersonaPrompt:
+    if persona_spec.attitude == "positive":
+        return PERSONA_PROMPTS[dataset_name][0]
+    elif persona_spec.attitude == "negative":
+        return PERSONA_PROMPTS[dataset_name][1]
+    elif persona_spec.attitude == "baseline":
+        return ""
+    else:
+        raise ValueError(f"Invalid attitude: {persona_spec.attitude}")
+
+
+def make_formatter_for_persona(
+    dataset_name: str,
+    persona_spec: PersonaSpec,
+):
+    formatter_factory = _make_formatter_factory_for_spec(Formatter, persona_spec)
+    persona_prompt = _make_persona_prompt(dataset_name, persona_spec)
+    return formatter_factory(persona_prompt)
diff --git a/experiments/persona_generalization/run_steering_experiment.py b/experiments/persona_generalization/run_steering_experiment.py
index b0fbd77..25b6ab6 100644
--- a/experiments/persona_generalization/run_steering_experiment.py
+++ b/experiments/persona_generalization/run_steering_experiment.py
@@ -4,36 +4,25 @@
 import numpy as np
 import pathlib
 
-from dataclasses import dataclass
-from typing import Literal, Callable
 from steering_vectors import train_steering_vector
 from steering_bench.build_training_data import build_steering_vector_training_data
 from steering_bench.core.evaluate import evaluate_propensities_on_dataset
 from steering_bench.utils.torch import load_model_with_quantization, EmptyTorchCUDACache
 from steering_bench.dataset import build_dataset, DatasetSpec
-from steering_bench.core.format import Formatter
 from steering_bench.core.pipeline import Pipeline
 from steering_bench.core.propensity import LogProbDifference
 from steering_bench.core.hook import SteeringHook
 
-from experiments.persona_generalization.persona_prompts import PERSONA_PROMPTS
+from experiments.persona_generalization.persona_prompts import (
+    PersonaSpec,
+    make_formatter_for_persona,
+)
 
 curr_dir = pathlib.Path(__file__).parent.absolute()
 save_dir = curr_dir / "persona_generalization_results"
 save_dir.mkdir(exist_ok=True)
 
 
-@dataclass
-class PersonaSpec:
-    attitude: Literal["positive", "negative", "baseline"]
-    prompt_strategy: Literal["system", "user", None]
-
-    def __str__(self):
-        if self.prompt_strategy is None:
-            return f"{self.attitude}"
-        return f"{self.attitude}_{self.prompt_strategy}"
-
-
 persona_specs = [
     PersonaSpec(attitude="positive", prompt_strategy="system"),
     # PersonaSpec(attitude="positive", prompt_strategy="user"),
@@ -42,41 +31,6 @@ def __str__(self):
     PersonaSpec(attitude="baseline", prompt_strategy=None),
 ]
 
-PersonaPrompt = str
-
-
-def _make_formatter_factory_for_spec(
-    formatter_cls: type[Formatter], persona_spec: PersonaSpec
-) -> Callable[[PersonaPrompt], Formatter]:
-    if persona_spec.prompt_strategy is None:
-        return lambda _: formatter_cls()
-    elif persona_spec.prompt_strategy == "system":
-        return lambda persona_prompt: formatter_cls(system_message=persona_prompt)
-    elif persona_spec.prompt_strategy == "user":
-        return lambda persona_prompt: formatter_cls(user_message=persona_prompt)
-
-    raise ValueError(f"Invalid prompt strategy: {persona_spec.prompt_strategy}")
-
-
-def _make_persona_prompt(dataset_name: str, persona_spec: PersonaSpec) -> PersonaPrompt:
-    if persona_spec.attitude == "positive":
-        return PERSONA_PROMPTS[dataset_name][0]
-    elif persona_spec.attitude == "negative":
-        return PERSONA_PROMPTS[dataset_name][1]
-    elif persona_spec.attitude == "baseline":
-        return ""
-    else:
-        raise ValueError(f"Invalid attitude: {persona_spec.attitude}")
-
-
-def make_formatter_for_persona(
-    dataset_name: str,
-    persona_spec: PersonaSpec,
-):
-    formatter_factory = _make_formatter_factory_for_spec(Formatter, persona_spec)
-    persona_prompt = _make_persona_prompt(dataset_name, persona_spec)
-    return formatter_factory(persona_prompt)
-
 
 
 if __name__ == "__main__":
@@ -86,7 +40,6 @@ def make_formatter_for_persona(
     test_spec = DatasetSpec(name=dataset_name, split="99%:100%", seed=0)
     train_dataset = build_dataset(train_spec)
     test_dataset = build_dataset(test_spec)
-    pos_persona_prompt, neg_persona_prompt = PERSONA_PROMPTS[dataset_name]
 
     # Load the model and tokenizer
     model_name = "meta-llama/Llama-2-7b-chat-hf"