From 8d1bd7da1160df267f0425fc4e0d7ec039a7c2be Mon Sep 17 00:00:00 2001 From: Daniel Tan <25474937+dtch1997@users.noreply.github.com> Date: Thu, 7 Mar 2024 12:05:04 +0000 Subject: [PATCH] feat: steering experiments (#132) * Add experimental code * fix: make_country_capital script * feat: add code to run steering experiment * update experiments code * fix: add --config_path arg * fix: config yaml parsing * chore: add more configs * chore: add even more configs * refactor: plotting * feat: add script to run sweep * fix: do not set completion template by default * refactor sweeps * refactor: token concept sweep * fix: bugbears * chore: add comments * fix: steering_index of datasets * test: steering token index * updating steering_vectors library version * evaluate on more layers * refactor: use steering-vectors code, log instead of print * chore: fix docstring * test: training, evaluating steering vectors * fix: minor --------- Co-authored-by: Daniel CH Tan Co-authored-by: David Chanin --- pdm.lock | 9 +- pyproject.toml | 2 +- repepo/core/evaluate.py | 10 +- repepo/data/multiple_choice/make_bats.py | 2 +- .../multiple_choice/make_caa_truthfulqa.py | 1 + .../make_country_capital_with_prompt.py | 33 +++ .../data/multiple_choice/make_mwe_persona.py | 2 +- repepo/data/multiple_choice/make_mwe_xrisk.py | 1 + .../believes-abortion-should-be-illegal.yaml | 15 + ...desire-for-recursive-self-improvement.yaml | 15 + .../experiments/configs/machiavellianism.yaml | 15 + repepo/experiments/configs/sycophancy.yaml | 15 + repepo/experiments/configs/truthfulqa.yaml | 15 + ...-HHH-to-be-deployed-in-the-real-world.yaml | 15 + repepo/experiments/run_sweep.py | 27 ++ .../experiments/sweeps/persona_and_token.sh | 3 + .../sweeps/run_sweep_persona_concepts.py | 35 +++ .../sweeps/run_sweep_token_concepts.py | 34 +++ .../steering/build_steering_training_data.py | 41 +++ repepo/steering/evaluate_steering_vector.py | 82 ++++++ repepo/steering/get_aggregator.py | 17 ++ repepo/steering/plot_results_by_layer.py | 26 ++ repepo/steering/run_experiment.py | 138 +++++++++ repepo/steering/utils/helpers.py | 275 ++++++++++++++++++ repepo/steering/utils/variables.py | 9 + tests/conftest.py | 7 +- .../test_make_caa_truthfulqa.py | 5 + .../multiple_choice/test_make_mwe_persona.py | 3 + .../multiple_choice/test_make_mwe_xrisk.py | 2 + tests/steering/test_run_experiment.py | 171 +++++++++++ 30 files changed, 1011 insertions(+), 14 deletions(-) create mode 100644 repepo/data/multiple_choice/make_country_capital_with_prompt.py create mode 100644 repepo/experiments/configs/believes-abortion-should-be-illegal.yaml create mode 100644 repepo/experiments/configs/desire-for-recursive-self-improvement.yaml create mode 100644 repepo/experiments/configs/machiavellianism.yaml create mode 100644 repepo/experiments/configs/sycophancy.yaml create mode 100644 repepo/experiments/configs/truthfulqa.yaml create mode 100644 repepo/experiments/configs/willingness-to-be-non-HHH-to-be-deployed-in-the-real-world.yaml create mode 100644 repepo/experiments/run_sweep.py create mode 100644 repepo/experiments/sweeps/persona_and_token.sh create mode 100644 repepo/experiments/sweeps/run_sweep_persona_concepts.py create mode 100644 repepo/experiments/sweeps/run_sweep_token_concepts.py create mode 100644 repepo/steering/build_steering_training_data.py create mode 100644 repepo/steering/evaluate_steering_vector.py create mode 100644 repepo/steering/get_aggregator.py create mode 100644 repepo/steering/plot_results_by_layer.py create mode 100644 
repepo/steering/run_experiment.py create mode 100644 repepo/steering/utils/helpers.py create mode 100644 repepo/steering/utils/variables.py create mode 100644 tests/steering/test_run_experiment.py diff --git a/pdm.lock b/pdm.lock index a2a9e56e..caf5fcf3 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform"] lock_version = "4.4" -content_hash = "sha256:d7949521dc64fb6e0ca1b741ff6f724da1b5a4fe1f4d609f1b5a5473a3067e00" +content_hash = "sha256:4a178c5c8c7c266cc0414ddbd1b7fada95c9d19478e5690062484fde06980909" [[package]] name = "absl-py" @@ -4064,16 +4064,17 @@ files = [ [[package]] name = "steering-vectors" -version = "0.5.0" +version = "0.10.0" requires_python = ">=3.10,<4.0" summary = "Steering vectors for transformer language models in Pytorch / Huggingface" dependencies = [ + "scikit-learn<2.0.0,>=1.4.0", "tqdm<5.0.0,>=4.1.0", "transformers<5.0.0,>=4.35.2", ] files = [ - {file = "steering_vectors-0.5.0-py3-none-any.whl", hash = "sha256:a5358318970b4a41ecba3f0b3e807c85552f019d621746836b5e0a4b10da1dd2"}, - {file = "steering_vectors-0.5.0.tar.gz", hash = "sha256:a799b158deba0753f761740861661b1fe1e58a8aac3fc5bdec44b6fe750b184d"}, + {file = "steering_vectors-0.10.0-py3-none-any.whl", hash = "sha256:d9a66f8b96449e230c58151d9ff7452f81a88e824a87961c51ace7020cc71ba1"}, + {file = "steering_vectors-0.10.0.tar.gz", hash = "sha256:70cca5bfddebb6e4768e9e5df4c16ff901c1f0f02472344449f89c8d42ba4f19"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index 2b708034..872e963a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "termcolor>=2.4.0", "bitsandbytes>=0.42.0", "nbdime>=4.0.1", - "steering-vectors>=0.3.0", + "steering-vectors>=0.10.0", "openai>=1.10.0", "arrr>=1.0.4", "spacy>=3.7.2", diff --git a/repepo/core/evaluate.py b/repepo/core/evaluate.py index 659cdeb1..12706f0f 100644 --- a/repepo/core/evaluate.py +++ b/repepo/core/evaluate.py @@ -13,6 +13,7 @@ from repepo.core.pipeline import Pipeline import numpy as np +import logging # eval_hooks allow us to do custom stuff to the pipeline only during evaluation EvalHook = Callable[[Pipeline], AbstractContextManager[None]] @@ -205,7 +206,7 @@ def evaluate( eval_hooks: Sequence[EvalHook] = [], show_progress: bool = True, tqdm_desc: str = "Evaluating", - verbose: bool = False, + logger: logging.Logger | None = None, ) -> EvalResult: # evaluate predictions: list[EvalPrediction] = [] @@ -217,9 +218,10 @@ def evaluate( for i, example in enumerate( tqdm(dataset, disable=not show_progress, desc=tqdm_desc) ): - if i == 0 and verbose: - print("Example full prompt:") - print(pipeline.build_full_prompt(example.positive)) + if logger is not None and i == 0: + logger.info( + f"Example full prompt: \n {pipeline.build_full_prompt(example.positive)}" + ) positive_probs = pipeline.calculate_output_logprobs(example.positive) negative_probs = pipeline.calculate_output_logprobs(example.negative) diff --git a/repepo/data/multiple_choice/make_bats.py b/repepo/data/multiple_choice/make_bats.py index 644103e5..4c8e4b1f 100644 --- a/repepo/data/multiple_choice/make_bats.py +++ b/repepo/data/multiple_choice/make_bats.py @@ -30,7 +30,7 @@ def convert_bats_dataset( def make_bats(): - """Make MWE dataset""" + """Make BATS dataset""" for dataset_path in get_raw_dataset_dir().glob("bats/*/*.txt"): with open(dataset_path, "r") as file: list_dataset = [] diff --git a/repepo/data/multiple_choice/make_caa_truthfulqa.py b/repepo/data/multiple_choice/make_caa_truthfulqa.py index 
7f9a2518..148ea87e 100644 --- a/repepo/data/multiple_choice/make_caa_truthfulqa.py +++ b/repepo/data/multiple_choice/make_caa_truthfulqa.py @@ -49,6 +49,7 @@ def convert_hf_truthfulqa_caa_dataset( Example( positive=Completion(prompt=prompt, response=correct_ans), negative=Completion(prompt=prompt, response=incorrect_ans), + steering_token_index=-2, ) ) return tqa_dataset diff --git a/repepo/data/multiple_choice/make_country_capital_with_prompt.py b/repepo/data/multiple_choice/make_country_capital_with_prompt.py new file mode 100644 index 00000000..da86bf93 --- /dev/null +++ b/repepo/data/multiple_choice/make_country_capital_with_prompt.py @@ -0,0 +1,33 @@ +from repepo.core.types import Dataset +from repepo.data.multiple_choice.make_bats import ( + convert_bats_dataset, + get_raw_dataset_dir, + build_dataset_filename, + get_dataset_dir, + jdump, +) + +PROMPT = "I live in" + + +def make_country_capital_with_prompt(): + dataset_path = get_raw_dataset_dir() / ( + "bats/3_Encyclopedic_semantics/E01 [country - capital].txt" + ) + assert dataset_path.exists() + + with open(dataset_path, "r") as file: + list_dataset = [] + for line in file: + element = line.strip().split() + assert len(element) == 2 + list_dataset.append(tuple(element)) + + dataset_name = "country-capital-with-prompt" + filename = build_dataset_filename(dataset_name) + mwe_dataset: Dataset = convert_bats_dataset(list_dataset, prompt=PROMPT) + jdump(mwe_dataset, get_dataset_dir() / "bats" / filename) + + +if __name__ == "__main__": + make_country_capital_with_prompt() diff --git a/repepo/data/multiple_choice/make_mwe_persona.py b/repepo/data/multiple_choice/make_mwe_persona.py index e29d30b1..18da69f0 100644 --- a/repepo/data/multiple_choice/make_mwe_persona.py +++ b/repepo/data/multiple_choice/make_mwe_persona.py @@ -93,7 +93,7 @@ def convert_mwe_personas_dataset_caa( response=element["answer_not_matching_behavior"], ) - ex = Example(positive=positive, negative=negative) + ex = Example(positive=positive, negative=negative, steering_token_index=-2) mwe_dataset.append(ex) return mwe_dataset diff --git a/repepo/data/multiple_choice/make_mwe_xrisk.py b/repepo/data/multiple_choice/make_mwe_xrisk.py index 387881c3..fc5c65e6 100644 --- a/repepo/data/multiple_choice/make_mwe_xrisk.py +++ b/repepo/data/multiple_choice/make_mwe_xrisk.py @@ -40,6 +40,7 @@ def convert_mwe_dataset( Example( positive=positive, negative=negative, + steering_token_index=-2, ) ) return mwe_dataset diff --git a/repepo/experiments/configs/believes-abortion-should-be-illegal.yaml b/repepo/experiments/configs/believes-abortion-should-be-illegal.yaml new file mode 100644 index 00000000..8b302e94 --- /dev/null +++ b/repepo/experiments/configs/believes-abortion-should-be-illegal.yaml @@ -0,0 +1,15 @@ + +use_base_model: False +model_size: 7b +train_dataset_name: believes-abortion-should-be-illegal +train_split_name: train-dev + +formatter: llama-chat-formatter +aggregator: mean +layers: [0, 15, 31] +multipliers: [-1, -0.5, 0, 0.5, 1] + +test_dataset_name: believes-abortion-should-be-illegal +test_split_name: val-dev +test_completion_template: "{prompt} My answer is: {response}" +verbose: True diff --git a/repepo/experiments/configs/desire-for-recursive-self-improvement.yaml b/repepo/experiments/configs/desire-for-recursive-self-improvement.yaml new file mode 100644 index 00000000..925f3e71 --- /dev/null +++ b/repepo/experiments/configs/desire-for-recursive-self-improvement.yaml @@ -0,0 +1,15 @@ + +use_base_model: False +model_size: 7b +train_dataset_name: 
desire-for-recursive-self-improvement +train_split_name: train-dev + +formatter: llama-chat-formatter +aggregator: mean +layers: [0, 15, 31] +multipliers: [-1, -0.5, 0, 0.5, 1] + +test_dataset_name: desire-for-recursive-self-improvement +test_split_name: val-dev +test_completion_template: "{prompt} My answer is: {response}" +verbose: True diff --git a/repepo/experiments/configs/machiavellianism.yaml b/repepo/experiments/configs/machiavellianism.yaml new file mode 100644 index 00000000..d8908579 --- /dev/null +++ b/repepo/experiments/configs/machiavellianism.yaml @@ -0,0 +1,15 @@ + +use_base_model: False +model_size: 7b +train_dataset_name: machiavellianism +train_split_name: train-dev + +formatter: llama-chat-formatter +aggregator: mean +layers: [0, 15, 31] +multipliers: [-1, -0.5, 0, 0.5, 1] + +test_dataset_name: machiavellianism +test_split_name: val-dev +test_completion_template: "{prompt} My answer is: {response}" +verbose: True diff --git a/repepo/experiments/configs/sycophancy.yaml b/repepo/experiments/configs/sycophancy.yaml new file mode 100644 index 00000000..b695bc6f --- /dev/null +++ b/repepo/experiments/configs/sycophancy.yaml @@ -0,0 +1,15 @@ + +use_base_model: False +model_size: 7b +train_dataset_name: sycophancy_train +train_split_name: train-dev + +formatter: llama-chat-formatter +aggregator: mean +layers: [0, 15, 31] +multipliers: [-1, -0.5, 0, 0.5, 1] + +test_dataset_name: sycophancy_train +test_split_name: val-dev +test_completion_template: "{prompt} My answer is: {response}" +verbose: True diff --git a/repepo/experiments/configs/truthfulqa.yaml b/repepo/experiments/configs/truthfulqa.yaml new file mode 100644 index 00000000..52f72f10 --- /dev/null +++ b/repepo/experiments/configs/truthfulqa.yaml @@ -0,0 +1,15 @@ + +use_base_model: False +model_size: 7b +train_dataset_name: truthfulqa +train_split_name: train-dev + +formatter: llama-chat-formatter +aggregator: mean +layers: [0, 15, 31] +multipliers: [-1, -0.5, 0, 0.5, 1] + +test_dataset_name: truthfulqa +test_split_name: val-dev +test_completion_template: "{prompt} My answer is: {response}" +verbose: True diff --git a/repepo/experiments/configs/willingness-to-be-non-HHH-to-be-deployed-in-the-real-world.yaml b/repepo/experiments/configs/willingness-to-be-non-HHH-to-be-deployed-in-the-real-world.yaml new file mode 100644 index 00000000..1f5de5f6 --- /dev/null +++ b/repepo/experiments/configs/willingness-to-be-non-HHH-to-be-deployed-in-the-real-world.yaml @@ -0,0 +1,15 @@ + +use_base_model: False +model_size: 7b +train_dataset_name: willingness-to-be-non-HHH-to-be-deployed-in-the-real-world +train_split_name: train-dev + +formatter: llama-chat-formatter +aggregator: mean +layers: [0, 15, 31] +multipliers: [-1, -0.5, 0, 0.5, 1] + +test_dataset_name: willingness-to-be-non-HHH-to-be-deployed-in-the-real-world +test_split_name: val-dev +test_completion_template: "{prompt} My answer is: {response}" +verbose: True diff --git a/repepo/experiments/run_sweep.py b/repepo/experiments/run_sweep.py new file mode 100644 index 00000000..428f3ba3 --- /dev/null +++ b/repepo/experiments/run_sweep.py @@ -0,0 +1,27 @@ +import matplotlib.pyplot as plt +from repepo.steering.run_experiment import run_experiment +from repepo.steering.utils.helpers import SteeringConfig, load_results +from repepo.steering.plot_results_by_layer import plot_results_by_layer + + +def run_sweep(configs: list[SteeringConfig], suffix=""): + for config in configs: + run_experiment(config) + + # load results + all_results = [] + for config in configs: + results = 
load_results(config) + all_results.append((config, results)) + + # plot results + fig, axs = plt.subplots(len(all_results), 1, figsize=(10, 20), squeeze=False) + for i, (config, results) in enumerate(all_results): + ax = axs[i, 0] # squeeze=False keeps axs 2D, so indexing works even for a single config + plot_results_by_layer(ax, config, results) + ax.set_title(f"{config.train_dataset_name}") + + fig.tight_layout() + if suffix: + suffix = f"_{suffix}" + fig.savefig(f"results{suffix}.png") diff --git a/repepo/experiments/sweeps/persona_and_token.sh b/repepo/experiments/sweeps/persona_and_token.sh new file mode 100644 index 00000000..dc347888 --- /dev/null +++ b/repepo/experiments/sweeps/persona_and_token.sh @@ -0,0 +1,3 @@ +#!/bin/bash +python repepo/experiments/sweeps/run_sweep_persona_concepts.py +python repepo/experiments/sweeps/run_sweep_token_concepts.py diff --git a/repepo/experiments/sweeps/run_sweep_persona_concepts.py b/repepo/experiments/sweeps/run_sweep_persona_concepts.py new file mode 100644 index 00000000..01b070b1 --- /dev/null +++ b/repepo/experiments/sweeps/run_sweep_persona_concepts.py @@ -0,0 +1,35 @@ +from repepo.experiments.run_sweep import run_sweep +from repepo.steering.utils.helpers import SteeringConfig + + +def list_configs(): + datasets = [ + "machiavellianism", + "desire-for-recursive-self-improvement", + "sycophancy_train", + "willingness-to-be-non-HHH-to-be-deployed-in-the-real-world", + # "truthful_qa", + "believes-abortion-should-be-illegal", + ] + + return [ + SteeringConfig( + use_base_model=False, + model_size="7b", + train_dataset_name=dataset_name, + train_split_name="train-dev", + formatter="llama-chat-formatter", + aggregator="mean", + verbose=True, + layers=[0, 11, 12, 13, 14, 15, 31], + multipliers=[-1, -0.5, 0, 0.5, 1], + test_dataset_name=dataset_name, + test_split_name="val-dev", + test_completion_template="{prompt} My answer is: {response}", + ) + for dataset_name in datasets + ] + + +if __name__ == "__main__": + run_sweep(list_configs(), "persona_concepts") diff --git a/repepo/experiments/sweeps/run_sweep_token_concepts.py b/repepo/experiments/sweeps/run_sweep_token_concepts.py new file mode 100644 index 00000000..a01e5886 --- /dev/null +++ b/repepo/experiments/sweeps/run_sweep_token_concepts.py @@ -0,0 +1,34 @@ +from repepo.experiments.run_sweep import run_sweep +from repepo.steering.utils.helpers import SteeringConfig + + +def list_configs(): + datasets = [ + "D02 [un+adj_reg]", + "D07 [verb+able_reg]", + "E01 [country - capital]", + "E06 [animal - young]", + "I01 [noun - plural_reg]", + "I07 [verb_inf - Ved]", + ] + + return [ + SteeringConfig( + use_base_model=False, + model_size="7b", + train_dataset_name=dataset_name, + train_split_name="train-dev", + formatter="identity-formatter", + aggregator="mean", + verbose=True, + layers=[0, 11, 12, 13, 14, 15, 31], + multipliers=[-1, -0.5, 0, 0.5, 1], + test_dataset_name=dataset_name, + test_split_name="val-dev", + ) + for dataset_name in datasets + ] + + +if __name__ == "__main__": + run_sweep(list_configs(), "token_concepts") diff --git a/repepo/steering/build_steering_training_data.py b/repepo/steering/build_steering_training_data.py new file mode 100644 index 00000000..48036415 --- /dev/null +++ b/repepo/steering/build_steering_training_data.py @@ -0,0 +1,41 @@ +import logging +from repepo.core.types import Dataset +from repepo.core.pipeline import Pipeline +from steering_vectors import SteeringVectorTrainingSample + + +def _validate_train_dataset(dataset: Dataset): + steering_token_index = dataset[0].steering_token_index + for example in dataset: + assert
example.steering_token_index == steering_token_index + + +def build_steering_vector_training_data( + pipeline: Pipeline, + dataset: Dataset, + logger: logging.Logger | None = None, +) -> list[SteeringVectorTrainingSample]: + # Validate that all examples have the same steering token index + _validate_train_dataset(dataset) + # After validation, we can assume that all examples have the same steering token index + read_token_index = dataset[0].steering_token_index + + # NOTE(dtch1997): Using SteeringVectorTrainingSample here + # to encode information about token index + steering_vector_training_data = [ + SteeringVectorTrainingSample( + positive_str=pipeline.build_full_prompt(example.positive), + negative_str=pipeline.build_full_prompt(example.negative), + read_positive_token_index=read_token_index, + read_negative_token_index=read_token_index, + ) + for example in dataset + ] + + if logger is not None: + # Log first example + datum = steering_vector_training_data[0] + logger.info(f"Positive example: \n {datum.positive_str}") + logger.info(f"Negative example: \n {datum.negative_str}") + + return steering_vector_training_data diff --git a/repepo/steering/evaluate_steering_vector.py b/repepo/steering/evaluate_steering_vector.py new file mode 100644 index 00000000..151e8f5e --- /dev/null +++ b/repepo/steering/evaluate_steering_vector.py @@ -0,0 +1,82 @@ +import logging + +from repepo.steering.utils.helpers import SteeringResult +from repepo.core.types import Dataset +from repepo.core.pipeline import Pipeline +from repepo.core.evaluate import ( + update_completion_template_at_eval, + select_repe_layer_at_eval, + set_repe_direction_multiplier_at_eval, + evaluate, + MultipleChoiceAccuracyEvaluator, + LogitDifferenceEvaluator, + NormalizedPositiveProbabilityEvaluator, +) +from repepo.core.hook import SteeringHook +from steering_vectors import SteeringVector, guess_and_enhance_layer_config + + +def evaluate_steering_vector( + pipeline: Pipeline, + steering_vector: SteeringVector, + dataset: Dataset, + layers: list[int], + multipliers: list[float], + patch_generation_tokens_only: bool = True, + skip_first_n_generation_tokens: int = 0, + completion_template: str | None = None, + logger: logging.Logger | None = None, +) -> list[SteeringResult]: + results = [] + + # Create steering hook and add it to pipeline + steering_hook = SteeringHook( + steering_vector=steering_vector, + direction_multiplier=0, + patch_generation_tokens_only=patch_generation_tokens_only, + skip_first_n_generation_tokens=skip_first_n_generation_tokens, + layer_config=guess_and_enhance_layer_config(pipeline.model), + ) + pipeline.hooks.append(steering_hook) + + for layer_id in layers: + for multiplier in multipliers: + eval_hooks = [ + set_repe_direction_multiplier_at_eval(multiplier), + select_repe_layer_at_eval(layer_id), + ] + if completion_template is not None: + eval_hooks.append( + update_completion_template_at_eval(completion_template) + ) + + # Run evaluate to get metrics + result = evaluate( + pipeline, + dataset, + eval_hooks=eval_hooks, + evaluators=[ + MultipleChoiceAccuracyEvaluator(), + LogitDifferenceEvaluator(), + NormalizedPositiveProbabilityEvaluator(), + ], + logger=logger, + ) + + result = SteeringResult( + layer_id=layer_id, + multiplier=multiplier, + mcq_acc=result.metrics["mcq_acc"], + logit_diff=result.metrics["logit_diff"], + pos_prob=result.metrics["pos_prob"], + ) + results.append(result) + if logger is not None: + logger.info( + f"Layer {layer_id}, multiplier {multiplier:.2f}: " + f"MCQ Accuracy 
{result.mcq_acc:.2f} " f"Positive Prob {result.pos_prob:.2f} " f"Logit Diff {result.logit_diff:.2f} " ) return results diff --git a/repepo/steering/get_aggregator.py b/repepo/steering/get_aggregator.py new file mode 100644 index 00000000..40dbb8b1 --- /dev/null +++ b/repepo/steering/get_aggregator.py @@ -0,0 +1,17 @@ +from steering_vectors.aggregators import ( + Aggregator, + mean_aggregator, + logistic_aggregator, + pca_aggregator, +) + +aggregators = { + "mean": mean_aggregator, + "logistic": logistic_aggregator, + "pca": pca_aggregator, +} + + +def get_aggregator(name: str) -> Aggregator: + """Look up one of the steering_vectors aggregators ("mean", "logistic", or "pca") by name and instantiate it.""" + return aggregators[name]() diff --git a/repepo/steering/plot_results_by_layer.py b/repepo/steering/plot_results_by_layer.py new file mode 100644 index 00000000..a838f6e3 --- /dev/null +++ b/repepo/steering/plot_results_by_layer.py @@ -0,0 +1,26 @@ +import matplotlib.pyplot as plt +from repepo.steering.utils.helpers import SteeringConfig, SteeringResult + + +def plot_results_by_layer( + ax: plt.Axes, config: SteeringConfig, results: list[SteeringResult] +): + for layer in config.layers: + layer_results = [x for x in results if x.layer_id == layer] + layer_results.sort(key=lambda x: x.multiplier) + + ax.plot( + [x.multiplier for x in layer_results], + [x.logit_diff for x in layer_results], + marker="o", + linestyle="dashed", + markersize=5, + linewidth=2.5, + label=f"Layer {layer}", + ) + + ax.set_title(f"{config.train_dataset_name}") + ax.set_xlabel("Multiplier") + ax.set_ylabel("Mean logit difference") + ax.legend() + return ax diff --git a/repepo/steering/run_experiment.py b/repepo/steering/run_experiment.py new file mode 100644 index 00000000..915fad5b --- /dev/null +++ b/repepo/steering/run_experiment.py @@ -0,0 +1,138 @@ +"""Defines a workflow to run a steering experiment.
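+The workflow: build contrastive prompt pairs from the train split, extract and aggregate paired activations into a per-layer steering vector, then evaluate that vector on the test split across the configured layers and multipliers, saving results keyed by a hash of the config.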
+ Example usage: python repepo/steering/run_experiment.py --config_path repepo/experiments/configs/sycophancy.yaml """ + +import matplotlib.pyplot as plt +import logging +import sys + +from pprint import pformat +from repepo.core.pipeline import Pipeline +from repepo.steering.utils.helpers import ( + SteeringConfig, + EmptyTorchCUDACache, + get_model_name, + get_model_and_tokenizer, + get_formatter, + make_dataset, + save_results, + load_results, + get_results_path, +) + +from repepo.steering.build_steering_training_data import ( + build_steering_vector_training_data, +) +from steering_vectors.train_steering_vector import ( + extract_activations, + aggregate_activations, + SteeringVector, +) + +from repepo.steering.get_aggregator import get_aggregator +from repepo.steering.evaluate_steering_vector import ( + evaluate_steering_vector, +) +from repepo.steering.plot_results_by_layer import plot_results_by_layer + + +def setup_logger() -> logging.Logger: + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + # Guard against attaching duplicate handlers (and duplicating log lines) when + # setup_logger is called repeatedly, e.g. once per config in a sweep + if logger.handlers: + return logger + # print to stdout + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(logging.INFO) + # Create a formatter and set it for the handler + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + + # Add the handler to the logger + logger.addHandler(handler) + + return logger + + +def run_plot_results(config: SteeringConfig): + fig, ax = plt.subplots() + results = load_results(config) + plot_results_by_layer(ax, config, results) + fig.tight_layout() + + save_path = f"results_{config.make_save_suffix()}.png" + print("Saving results to: ", save_path) + fig.savefig(save_path) + + +def run_experiment(config: SteeringConfig): + logger = setup_logger() + logger.info(f"Running experiment with config: \n{pformat(config)}") + + if get_results_path(config).exists(): + logger.info(f"Results already exist for {config}. Skipping.") + return + + # Set up pipeline + model_name = get_model_name(config.use_base_model, config.model_size) + model, tokenizer = get_model_and_tokenizer(model_name) + formatter = get_formatter(config.formatter) + pipeline = Pipeline(model, tokenizer, formatter=formatter) + + # Set up train dataset + train_dataset = make_dataset(config.train_dataset_name, config.train_split_name) + steering_vector_training_data = build_steering_vector_training_data( + pipeline, train_dataset, logger=logger + ) + + # Extract activations + with EmptyTorchCUDACache(): + pos_acts, neg_acts = extract_activations( + pipeline.model, + pipeline.tokenizer, + steering_vector_training_data, + show_progress=True, + move_to_cpu=True, + ) + + # TODO: compute intermediate metrics + + # Aggregate activations + aggregator = get_aggregator(config.aggregator) + with EmptyTorchCUDACache(): + agg_acts = aggregate_activations( + pos_acts, + neg_acts, + aggregator, + ) + steering_vector = SteeringVector( + layer_activations=agg_acts, + # TODO: make config option?
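+ # NOTE: "decoder_block" steers the output of each full transformer block (the CAA setup); the steering_vectors library also defines finer-grained layer types such as "self_attn" and "mlp" that could plausibly be exposed via the config here (an assumption about the TODO's intent, not confirmed in this patch)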
+ layer_type="decoder_block", + ) + + # Evaluate steering vector + test_dataset = make_dataset(config.test_dataset_name, config.test_split_name) + with EmptyTorchCUDACache(): + results = evaluate_steering_vector( + pipeline=pipeline, + steering_vector=steering_vector, + dataset=test_dataset, + layers=config.layers, + multipliers=config.multipliers, + completion_template=config.test_completion_template, + logger=logger, + ) + + # Save results + save_results(config, results) + + +if __name__ == "__main__": + import simple_parsing + + config = simple_parsing.parse(config_class=SteeringConfig, add_config_path_arg=True) + run_experiment(config) + run_plot_results(config) diff --git a/repepo/steering/utils/helpers.py b/repepo/steering/utils/helpers.py new file mode 100644 index 00000000..783babca --- /dev/null +++ b/repepo/steering/utils/helpers.py @@ -0,0 +1,275 @@ +import os +import pickle +import pathlib +import torch +import pandas as pd +import gc +import hashlib + +from typing import cast, Hashable, Literal +from dataclasses import dataclass +from pyrallis import field +from transformers import AutoModelForCausalLM, AutoTokenizer + +from transformers import BitsAndBytesConfig +from repepo.core.types import Example +from repepo.core.format import Formatter +from repepo.data.make_dataset import ( + DatasetSpec, + make_dataset as _make_dataset, +) +from repepo.steering.utils.variables import ( + WORK_DIR, + LOAD_IN_4BIT, + LOAD_IN_8BIT, + DATASET_DIR, +) + +token = os.getenv("HF_TOKEN") +experiment_suite = os.getenv("EXPERIMENT_SUITE", "steering-vectors") + +SPLITS = { + "train": "0%:40%", + "val": "40%:50%", + "test": "50%:100%", + # For development, use 10 examples per split + "train-dev": "0%:+10", + "val-dev": "40%:+10", + "test-dev": "50%:+10", +} + + +class EmptyTorchCUDACache: + """Context manager to free GPU memory""" + + def __enter__(self, *args, **kwargs): + return self + + def __exit__(self, *args, **kwargs): + gc.collect() + torch.cuda.empty_cache() + + +LayerwiseMetricsDict = dict[int, float] +LayerwiseConceptVectorsDict = dict[int, torch.Tensor] + + +def pretty_print_example(example: Example): + print("Not implemented") + + +def get_model_name(use_base_model: bool, model_size: str): + """Gets model name for Llama-[7b,13b], base model or chat model""" + if use_base_model: + model_name = f"meta-llama/Llama-2-{model_size}-hf" + else: + model_name = f"meta-llama/Llama-2-{model_size}-chat-hf" + return model_name + + +def get_model_and_tokenizer( + model_name: str, + load_in_4bit: bool = bool(LOAD_IN_4BIT), + load_in_8bit: bool = bool(LOAD_IN_8BIT), +): + bnb_config = BitsAndBytesConfig( + load_in_4bit=load_in_4bit, + load_in_8bit=load_in_8bit, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_quant_type="nf4", + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) + # Note: you must have installed 'accelerate', 'bitsandbytes' to load in 8bit + model = AutoModelForCausalLM.from_pretrained( + model_name, token=token, quantization_config=bnb_config, device_map="auto" + ) + return model, tokenizer + + +def get_formatter( + formatter_name: str = "llama-chat-formatter", +) -> Formatter: + if formatter_name == "llama-chat-formatter": + from repepo.core.format import LlamaChatFormatter + + return LlamaChatFormatter() + elif formatter_name == "identity-formatter": + from repepo.core.format import IdentityFormatter + + return IdentityFormatter() + else: + raise ValueError(f"Unknown formatter: {formatter_name}") + + +_layers = [0, 15, 31] # only the first, middle, and last
layer of llama-7b +_multipliers = [-1, -0.5, 0, 0.5, 1] + + +@dataclass +class SteeringConfig: + use_base_model: bool = field(default=False) + model_size: str = field(default="7b") + train_dataset_name: str = field(default="sycophancy_train") + train_split_name: str = field(default="train-dev") + formatter: str = field(default="llama-chat-formatter") + aggregator: str = field(default="mean") + verbose: bool = True + layers: list[int] = field(default=_layers, is_mutable=True) + multipliers: list[float] = field(default=_multipliers, is_mutable=True) + test_dataset_name: str = field(default="sycophancy_train") + test_split_name: str = field(default="val-dev") + test_completion_template: str = field(default="{prompt} {response}") + + def make_save_suffix(self) -> str: + # TODO: any way to loop over fields instead of hardcoding? + config_str = ( + f"use-base-model={self.use_base_model}_" + f"model-size={self.model_size}_" + f"formatter={self.formatter}_" + f"train-dataset={self.train_dataset_name}_" + f"train-split={self.train_split_name}_" + f"aggregator={self.aggregator}_" + f"layers={self.layers}_" + f"multipliers={self.multipliers}_" + f"test-dataset={self.test_dataset_name}_" + f"test-split={self.test_split_name}_" + f"test-completion-template={self.test_completion_template}" + ) + return hashlib.md5(config_str.encode()).hexdigest() + + +def get_experiment_path( + experiment_suite: str = experiment_suite, +) -> pathlib.Path: + return WORK_DIR / experiment_suite + + +ActivationType = Literal["positive", "negative", "difference"] + + +def save_activations( + config: SteeringConfig, + activations: dict[int, list[torch.Tensor]], + activation_type: ActivationType = "difference", +): + experiment_path = get_experiment_path() + result_save_suffix = config.make_save_suffix() + activations_save_dir = experiment_path / "activations" + activations_save_dir.mkdir(parents=True, exist_ok=True) + torch.save( + activations, + activations_save_dir / f"activations_{activation_type}_{result_save_suffix}.pt", + ) + + +def load_activations( + config: SteeringConfig, activation_type: ActivationType = "difference" +) -> dict[int, list[torch.Tensor]]: + experiment_path = get_experiment_path() + result_save_suffix = config.make_save_suffix() + activations_save_dir = experiment_path / "activations" + return torch.load( + activations_save_dir / f"activations_{activation_type}_{result_save_suffix}.pt" + ) + + +def save_concept_vectors( + config: SteeringConfig, concept_vectors: dict[int, torch.Tensor] +): + experiment_path = get_experiment_path() + result_save_suffix = config.make_save_suffix() + activations_save_dir = experiment_path / "vectors" + activations_save_dir.mkdir(parents=True, exist_ok=True) + torch.save( + concept_vectors, + activations_save_dir / f"concept_vectors_{result_save_suffix}.pt", + ) + + +def save_metrics( + config: SteeringConfig, + metric_name: str, + metrics: dict[Hashable, float], +): + experiment_path = get_experiment_path() + result_save_suffix = config.make_save_suffix() + metrics_save_dir = experiment_path / "metrics" + metrics_save_dir.mkdir(parents=True, exist_ok=True) + torch.save( + metrics, + metrics_save_dir / f"{metric_name}_{result_save_suffix}.pt", + ) + + +def load_metrics( + config: SteeringConfig, + metric_name: str, +) -> LayerwiseMetricsDict: + """Load layer-wise metrics for a given metric_name and config.""" + experiment_path = get_experiment_path() + result_save_suffix = config.make_save_suffix() + metrics_save_dir = experiment_path / "metrics" + return torch.load(metrics_save_dir /
f"{metric_name}_{result_save_suffix}.pt") + + +def load_concept_vectors( + config: SteeringConfig, +) -> LayerwiseConceptVectorsDict: + """Load layer-wise concept vectors for a given config.""" + experiment_path = get_experiment_path() + result_save_suffix = config.make_save_suffix() + activations_save_dir = experiment_path / "vectors" + return torch.load(activations_save_dir / f"concept_vectors_{result_save_suffix}.pt") + + +@dataclass +class SteeringResult: + layer_id: int + multiplier: float + mcq_acc: float + logit_diff: float + pos_prob: float + + +def save_results( + config: SteeringConfig, + results: list[SteeringResult], +): + experiment_path = get_experiment_path() + result_save_suffix = config.make_save_suffix() + results_save_dir = experiment_path / "results" + results_save_dir.mkdir(parents=True, exist_ok=True) + with open(results_save_dir / f"results_{result_save_suffix}.pickle", "wb") as f: + pickle.dump(results, f) + + +def get_results_path(config: SteeringConfig) -> pathlib.Path: + experiment_path = get_experiment_path() + result_save_suffix = config.make_save_suffix() + results_save_dir = experiment_path / "results" + return results_save_dir / f"results_{result_save_suffix}.pickle" + + +def load_results( + config: SteeringConfig, +) -> list[SteeringResult]: + experiment_path = get_experiment_path() + result_save_suffix = config.make_save_suffix() + results_save_dir = experiment_path / "results" + with open(results_save_dir / f"results_{result_save_suffix}.pickle", "rb") as f: + return cast(list[SteeringResult], pickle.load(f)) + + +def make_dataset( + name: str, + split_name: str = "train", +): + if split_name not in SPLITS: + raise ValueError(f"Unknown split name: {split_name}") + return _make_dataset(DatasetSpec(name=name, split=SPLITS[split_name]), DATASET_DIR) + + +def convert_to_dataframe(example_list: list[Example]) -> pd.DataFrame: + df = pd.DataFrame([vars(example) for example in example_list]) + return df diff --git a/repepo/steering/utils/variables.py b/repepo/steering/utils/variables.py new file mode 100644 index 00000000..528aed23 --- /dev/null +++ b/repepo/steering/utils/variables.py @@ -0,0 +1,9 @@ +import os +from pathlib import Path +from repepo.variables import Environ + +WORK_DIR = Path(os.getenv("WORK_DIR", ".")).absolute() +DATASET_DIR = Path(Environ.DatasetDir) + +LOAD_IN_4BIT: int = int(os.getenv("LOAD_IN_4BIT", 0)) +LOAD_IN_8BIT: int = int(os.getenv("LOAD_IN_8BIT", 0)) diff --git a/tests/conftest.py b/tests/conftest.py index 712a6cc3..d7b6661f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import pytest from pytest_mock import MockerFixture +from typing import Literal import torch from transformers import ( @@ -12,7 +13,7 @@ from repepo.core.types import Tokenizer -_device: str = "cuda" if torch.cuda.is_available() else "cpu" +_device: Literal["cpu", "cuda"] = "cuda" if torch.cuda.is_available() else "cpu" # mock openai.OpenAI for testing @@ -22,7 +23,7 @@ def get_openai_client_mock(mocker: MockerFixture): @pytest.fixture -def device() -> str: +def device() -> Literal["cpu", "cuda"]: return _device @@ -79,7 +80,7 @@ def empty_llama_model() -> LlamaForCausalLM: hidden_size=1024, intermediate_size=2752, ) - return LlamaForCausalLM(config) + return LlamaForCausalLM(config).to(_device).eval() @pytest.fixture diff --git a/tests/data/multiple_choice/test_make_caa_truthfulqa.py b/tests/data/multiple_choice/test_make_caa_truthfulqa.py index a00d170b..4af3a140 100644 --- a/tests/data/multiple_choice/test_make_caa_truthfulqa.py +++ 
b/tests/data/multiple_choice/test_make_caa_truthfulqa.py @@ -26,6 +26,7 @@ def test_convert_hf_truthfulqa_caa_dataset_picks_first_incorrect_by_default() -> prompt="What bear is best?\n\n(A) brown bear\n(B) black bear", response="(A)", ), + steering_token_index=-2, ), Example( positive=Completion( @@ -34,6 +35,7 @@ def test_convert_hf_truthfulqa_caa_dataset_picks_first_incorrect_by_default() -> negative=Completion( prompt="What is the best color?\n\n(A) blue\n(B) red", response="(B)" ), + steering_token_index=-2, ), ] @@ -63,6 +65,7 @@ def test_convert_hf_truthfulqa_caa_dataset_can_duplicate_correct_answers() -> No prompt="What bear is best?\n\n(A) brown bear\n(B) black bear", response="(A)", ), + steering_token_index=-2, ), Example( positive=Completion( @@ -73,6 +76,7 @@ def test_convert_hf_truthfulqa_caa_dataset_can_duplicate_correct_answers() -> No prompt="What bear is best?\n\n(A) black bear\n(B) panda bear", response="(B)", ), + steering_token_index=-2, ), Example( positive=Completion( @@ -81,5 +85,6 @@ def test_convert_hf_truthfulqa_caa_dataset_can_duplicate_correct_answers() -> No negative=Completion( prompt="What is the best color?\n\n(A) red\n(B) blue", response="(A)" ), + steering_token_index=-2, ), ] diff --git a/tests/data/multiple_choice/test_make_mwe_persona.py b/tests/data/multiple_choice/test_make_mwe_persona.py index 97e41637..d0781003 100644 --- a/tests/data/multiple_choice/test_make_mwe_persona.py +++ b/tests/data/multiple_choice/test_make_mwe_persona.py @@ -32,6 +32,7 @@ def test_convert_mwe_personas_dataset_caa_english() -> None: ) assert dataset[0].positive.response == "(B)" assert dataset[0].negative.response == "(A)" + assert dataset[0].steering_token_index == -2 assert ( dataset[1].positive.prompt == dataset[1].negative.prompt and dataset[1].positive.prompt @@ -67,6 +68,7 @@ def test_convert_mwe_personas_dataset_caa_fr() -> None: ) assert dataset[0].positive.response == "(B)" assert dataset[0].negative.response == "(A)" + assert dataset[0].steering_token_index == -2 def test_convert_mwe_personas_dataset_caa_fr_with_ctx() -> None: @@ -90,3 +92,4 @@ def test_convert_mwe_personas_dataset_caa_fr_with_ctx() -> None: ) assert dataset[0].positive.response == "(B)" assert dataset[0].negative.response == "(A)" + assert dataset[0].steering_token_index == -2 diff --git a/tests/data/multiple_choice/test_make_mwe_xrisk.py b/tests/data/multiple_choice/test_make_mwe_xrisk.py index 9b4ea91e..55d470ab 100644 --- a/tests/data/multiple_choice/test_make_mwe_xrisk.py +++ b/tests/data/multiple_choice/test_make_mwe_xrisk.py @@ -23,6 +23,7 @@ def test_convert_mwe_dataset() -> None: ) assert result[0].positive.response == "(A)" assert result[0].negative.response == "(B)" + assert result[0].steering_token_index == -2 assert result[1].positive.prompt == result[1].negative.prompt assert ( result[1].positive.prompt @@ -46,3 +47,4 @@ def test_convert_mwe_data_strips_meta_tags() -> None: result[0].positive.prompt == "What is the capital of France? 
(A) Paris (B) London" ) + assert result[0].steering_token_index == -2 diff --git a/tests/steering/test_run_experiment.py b/tests/steering/test_run_experiment.py new file mode 100644 index 00000000..be0e2645 --- /dev/null +++ b/tests/steering/test_run_experiment.py @@ -0,0 +1,171 @@ +import torch +from repepo.core.hook import SteeringHook + +from repepo.core.format import LlamaChatFormatter +from repepo.core.types import Dataset, Example, Tokenizer, Completion +from repepo.core.pipeline import Pipeline, PipelineContext + +from repepo.steering.utils.helpers import EmptyTorchCUDACache +from repepo.steering.run_experiment import ( + build_steering_vector_training_data, + extract_activations, + get_aggregator, + aggregate_activations, + SteeringVector, +) + +from steering_vectors import guess_and_enhance_layer_config + +from transformers import LlamaForCausalLM +from tests._original_caa.llama_wrapper import LlamaWrapper + + +def test_get_steering_vector_matches_caa( + empty_llama_model: LlamaForCausalLM, llama_chat_tokenizer: Tokenizer +): + model = empty_llama_model + tokenizer = llama_chat_tokenizer + pipeline = Pipeline( + model, + tokenizer, + formatter=LlamaChatFormatter(), + ) + + #### + # First, calculate our SV + dataset: Dataset = [ + Example( + positive=Completion("Paris is in", "France"), + negative=Completion("Paris is in", "Germany"), + steering_token_index=-2, + ), + ] + steering_vector_training_data = build_steering_vector_training_data( + pipeline, dataset + ) + + layers = [0, 1, 2] + + # Extract activations + with EmptyTorchCUDACache(): + pos_acts, neg_acts = extract_activations( + pipeline.model, + pipeline.tokenizer, + steering_vector_training_data, + layers=layers, + show_progress=True, + move_to_cpu=True, + ) + + # Aggregate activations + aggregator = get_aggregator("mean") + with EmptyTorchCUDACache(): + agg_acts = aggregate_activations( + pos_acts, + neg_acts, + aggregator, + ) + steering_vector = SteeringVector( + layer_activations=agg_acts, + layer_type="decoder_block", + ) + + # hackily translated from generate_vectors.py script + tokenized_data = [ + (tokenizer.encode(svtd.positive_str), tokenizer.encode(svtd.negative_str)) + for svtd in steering_vector_training_data + ] + pos_activations = dict([(layer, []) for layer in layers]) + neg_activations = dict([(layer, []) for layer in layers]) + wrapped_model = LlamaWrapper(model, tokenizer) + + for p_tokens, n_tokens in tokenized_data: + p_tokens = torch.tensor(p_tokens).unsqueeze(0).to(model.device) + n_tokens = torch.tensor(n_tokens).unsqueeze(0).to(model.device) + wrapped_model.reset_all() + wrapped_model.get_logits(p_tokens) + for layer in layers: + p_activations = wrapped_model.get_last_activations(layer) + p_activations = p_activations[0, -2, :].detach().cpu() + pos_activations[layer].append(p_activations) + wrapped_model.reset_all() + wrapped_model.get_logits(n_tokens) + for layer in layers: + n_activations = wrapped_model.get_last_activations(layer) + n_activations = n_activations[0, -2, :].detach().cpu() + neg_activations[layer].append(n_activations) + + caa_vecs_by_layer = {} + for layer in layers: + all_pos_layer = torch.stack(pos_activations[layer]) + all_neg_layer = torch.stack(neg_activations[layer]) + caa_vecs_by_layer[layer] = (all_pos_layer - all_neg_layer).mean(dim=0) + + for layer in layers: + assert torch.allclose( + steering_vector.layer_activations[layer], + caa_vecs_by_layer[layer], + atol=1e-5, + ), f"Non-matching activations at layer {layer}" + + +def 
test_evaluate_steering_vector_matches_caa_llama_wrapper( + empty_llama_model: LlamaForCausalLM, llama_chat_tokenizer: Tokenizer +) -> None: + model = empty_llama_model + tokenizer = llama_chat_tokenizer + pipeline = Pipeline( + model, + tokenizer, + formatter=LlamaChatFormatter(), + ) + test_example = Example( + positive=Completion("Paris is in", "France"), + negative=Completion("Paris is in", "Germany"), + steering_token_index=-2, + ) + + layers = [0, 1, 2] + multiplier = 7 + + # Create a dummy SV + steering_vector = SteeringVector( + layer_activations={ + layer: torch.randn(1024, device=model.device) for layer in layers + } + ) + hook = SteeringHook( + steering_vector=steering_vector, + direction_multiplier=multiplier, + patch_generation_tokens_only=True, + skip_first_n_generation_tokens=1, + layer_config=guess_and_enhance_layer_config(pipeline.model), + ) + + # hackily recreating what the pipeline does during logprobs + base_prompt = pipeline.build_generation_prompt(test_example.positive) + full_prompt = pipeline.build_full_prompt(test_example.positive) + inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device) + ctx = PipelineContext( + method="logprobs", + base_prompt=base_prompt, + full_prompt=full_prompt, + inputs=inputs, + pipeline=pipeline, + ) + orig_logits = model(**inputs).logits + with hook(ctx): + our_logits = model(**inputs).logits + + assert isinstance(hook, SteeringHook) # keep pyright happy + wrapped_model = LlamaWrapper(model, tokenizer, add_only_after_end_str=True) + wrapped_model.reset_all() + for layer in layers: + wrapped_model.set_add_activations( + layer, multiplier * hook.steering_vector.layer_activations[layer] + ) + + caa_logits = wrapped_model.get_logits(inputs["input_ids"].to(model.device)) + # only the final answer tokens should be different + assert torch.allclose(our_logits[0, :-2], orig_logits[0, :-2]) + assert torch.allclose(our_logits, caa_logits)