feat: steering experiments (#132)
* Add experimental code

* fix: make_country_capital script

* feat: add code to run steering experiment

* update experiments code

* fix: add --config_path arg

* fix: config yaml parsing

* chore: add more configs

* chore: add even more configs

* refactor: plotting

* feat: add script to run sweep

* fix: do not set completion template by default

* refactor sweeps

* refactor: token concept sweep

* fix: bugbears

* chore: add comments

* fix: steering_index of datasets

* test: steering token index

* updating steering_vectors library version

* evaluate on more layers

* refactor: use steering-vectors code, log instead of print

* chore: fix docstring

* test: training, evaluating steering vectors

* fix: minor

---------

Co-authored-by: Daniel CH Tan <dtch1997@users.noreply.github.com>
Co-authored-by: David Chanin <chanindav@gmail.com>
3 people authored Mar 7, 2024
1 parent 986e9b2 commit 8d1bd7d
Showing 30 changed files with 1,011 additions and 14 deletions.
9 changes: 5 additions & 4 deletions pdm.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
     "termcolor>=2.4.0",
     "bitsandbytes>=0.42.0",
     "nbdime>=4.0.1",
-    "steering-vectors>=0.3.0",
+    "steering-vectors>=0.10.0",
     "openai>=1.10.0",
     "arrr>=1.0.4",
     "spacy>=3.7.2",
10 changes: 6 additions & 4 deletions repepo/core/evaluate.py
@@ -13,6 +13,7 @@
 from repepo.core.pipeline import Pipeline
 
 import numpy as np
+import logging
 
 # eval_hooks allow us to do custom stuff to the pipeline only during evaluation
 EvalHook = Callable[[Pipeline], AbstractContextManager[None]]
@@ -205,7 +206,7 @@ def evaluate(
     eval_hooks: Sequence[EvalHook] = [],
     show_progress: bool = True,
     tqdm_desc: str = "Evaluating",
-    verbose: bool = False,
+    logger: logging.Logger | None = None,
 ) -> EvalResult:
     # evaluate
     predictions: list[EvalPrediction] = []
@@ -217,9 +218,10 @@
     for i, example in enumerate(
         tqdm(dataset, disable=not show_progress, desc=tqdm_desc)
     ):
-        if i == 0 and verbose:
-            print("Example full prompt:")
-            print(pipeline.build_full_prompt(example.positive))
+        if logger is not None and i == 0:
+            logger.info(
+                f"Example full prompt: \n {pipeline.build_full_prompt(example.positive)}"
+            )
         positive_probs = pipeline.calculate_output_logprobs(example.positive)
         negative_probs = pipeline.calculate_output_logprobs(example.negative)
 
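With `verbose` gone, callers that want the example-prompt logging now pass a standard `logging.Logger` to `evaluate`. A minimal sketch of the new call shape (the pipeline/dataset construction and any other `evaluate` arguments are elided; the variable names here are assumptions, not code from this commit):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Assumes `pipeline` and `dataset` were built elsewhere.
result = evaluate(
    pipeline,
    dataset,
    show_progress=True,
    tqdm_desc="Evaluating",
    logger=logger,  # omit (or pass None) to keep evaluation quiet
)
```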
2 changes: 1 addition & 1 deletion repepo/data/multiple_choice/make_bats.py
@@ -30,7 +30,7 @@ def convert_bats_dataset(
 
 
 def make_bats():
-    """Make MWE dataset"""
+    """Make BATS dataset"""
     for dataset_path in get_raw_dataset_dir().glob("bats/*/*.txt"):
         with open(dataset_path, "r") as file:
             list_dataset = []
1 change: 1 addition & 0 deletions repepo/data/multiple_choice/make_caa_truthfulqa.py
@@ -49,6 +49,7 @@ def convert_hf_truthfulqa_caa_dataset(
         Example(
             positive=Completion(prompt=prompt, response=correct_ans),
             negative=Completion(prompt=prompt, response=incorrect_ans),
+            steering_token_index=-2,
         )
     )
     return tqa_dataset
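The added `steering_token_index=-2` records where activations should be read for each example: the second-to-last token of the formatted completion. For CAA-style multiple-choice answers that end in something like `(A)`, that position is the answer letter rather than the trailing parenthesis. A toy illustration of the indexing only (the token split shown is a simplifying assumption; real tokenizers differ):

```python
# Toy example: if a response such as "(A)" split into these three tokens,
# index -2 would select the answer letter, not the closing parenthesis.
tokens = ["(", "A", ")"]
assert tokens[-2] == "A"
```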
33 changes: 33 additions & 0 deletions repepo/data/multiple_choice/make_country_capital_with_prompt.py
@@ -0,0 +1,33 @@
from repepo.core.types import Dataset
from repepo.data.multiple_choice.make_bats import (
    convert_bats_dataset,
    get_raw_dataset_dir,
    build_dataset_filename,
    get_dataset_dir,
    jdump,
)

PROMPT = "I live in"


def make_country_capital_with_prompt():
    dataset_path = get_raw_dataset_dir() / (
        "bats/3_Encyclopedic_semantics/E01 [country - capital].txt"
    )
    assert dataset_path.exists()

    with open(dataset_path, "r") as file:
        list_dataset = []
        for line in file:
            element = line.strip().split()
            assert len(element) == 2
            list_dataset.append(tuple(element))

    dataset_name = "country-capital-with-prompt"
    filename = build_dataset_filename(dataset_name)
    mwe_dataset: Dataset = convert_bats_dataset(list_dataset, prompt=PROMPT)
    jdump(mwe_dataset, get_dataset_dir() / "bats" / filename)


if __name__ == "__main__":
    make_country_capital_with_prompt()
2 changes: 1 addition & 1 deletion repepo/data/multiple_choice/make_mwe_persona.py
@@ -93,7 +93,7 @@ def convert_mwe_personas_dataset_caa(
             response=element["answer_not_matching_behavior"],
         )
 
-        ex = Example(positive=positive, negative=negative)
+        ex = Example(positive=positive, negative=negative, steering_token_index=-2)
         mwe_dataset.append(ex)
     return mwe_dataset
 
1 change: 1 addition & 0 deletions repepo/data/multiple_choice/make_mwe_xrisk.py
@@ -40,6 +40,7 @@ def convert_mwe_dataset(
         Example(
             positive=positive,
             negative=negative,
+            steering_token_index=-2,
         )
     )
     return mwe_dataset
15 changes: 15 additions & 0 deletions repepo/experiments/configs/believes-abortion-should-be-illegal.yaml
@@ -0,0 +1,15 @@

use_base_model: False
model_size: 7b
train_dataset_name: believes-abortion-should-be-illegal
train_split_name: train-dev

formatter: llama-chat-formatter
aggregator: mean
layers: [0, 15, 31]
multipliers: [-1, -0.5, 0, 0.5, 1]

test_dataset_name: believes-abortion-should-be-illegal
test_split_name: val-dev
test_completion_template: "{prompt} My answer is: {response}"
verbose: True
15 changes: 15 additions & 0 deletions repepo/experiments/configs/desire-for-recursive-self-improvement.yaml
@@ -0,0 +1,15 @@

use_base_model: False
model_size: 7b
train_dataset_name: desire-for-recursive-self-improvement
train_split_name: train-dev

formatter: llama-chat-formatter
aggregator: mean
layers: [0, 15, 31]
multipliers: [-1, -0.5, 0, 0.5, 1]

test_dataset_name: desire-for-recursive-self-improvement
test_split_name: val-dev
test_completion_template: "{prompt} My answer is: {response}"
verbose: True
15 changes: 15 additions & 0 deletions repepo/experiments/configs/machiavellianism.yaml
@@ -0,0 +1,15 @@

use_base_model: False
model_size: 7b
train_dataset_name: machiavellianism
train_split_name: train-dev

formatter: llama-chat-formatter
aggregator: mean
layers: [0, 15, 31]
multipliers: [-1, -0.5, 0, 0.5, 1]

test_dataset_name: machiavellianism
test_split_name: val-dev
test_completion_template: "{prompt} My answer is: {response}"
verbose: True
15 changes: 15 additions & 0 deletions repepo/experiments/configs/sycophancy.yaml
@@ -0,0 +1,15 @@

use_base_model: False
model_size: 7b
train_dataset_name: sycophancy_train
train_split_name: train-dev

formatter: llama-chat-formatter
aggregator: mean
layers: [0, 15, 31]
multipliers: [-1, -0.5, 0, 0.5, 1]

test_dataset_name: sycophancy_train
test_split_name: val-dev
test_completion_template: "{prompt} My answer is: {response}"
verbose: True
15 changes: 15 additions & 0 deletions repepo/experiments/configs/truthfulqa.yaml
@@ -0,0 +1,15 @@

use_base_model: False
model_size: 7b
train_dataset_name: truthfulqa
train_split_name: train-dev

formatter: llama-chat-formatter
aggregator: mean
layers: [0, 15, 31]
multipliers: [-1, -0.5, 0, 0.5, 1]

test_dataset_name: truthfulqa
test_split_name: val-dev
test_completion_template: "{prompt} My answer is: {response}"
verbose: True
15 changes: 15 additions & 0 deletions repepo/experiments/configs/willingness-to-be-non-HHH-to-be-deployed-in-the-real-world.yaml
@@ -0,0 +1,15 @@

use_base_model: False
model_size: 7b
train_dataset_name: willingness-to-be-non-HHH-to-be-deployed-in-the-real-world
train_split_name: train-dev

formatter: llama-chat-formatter
aggregator: mean
layers: [0, 15, 31]
multipliers: [-1, -0.5, 0, 0.5, 1]

test_dataset_name: willingness-to-be-non-HHH-to-be-deployed-in-the-real-world
test_split_name: val-dev
test_completion_template: "{prompt} My answer is: {response}"
verbose: True
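
Each YAML file maps field-for-field onto a `SteeringConfig`, and `run_experiment` consumes one config at a time (both are imported in `run_sweep.py` below). As a sketch, the sycophancy config above could equivalently be constructed in Python, assuming `SteeringConfig` accepts exactly the fields the sweep scripts pass:

```python
from repepo.steering.run_experiment import run_experiment
from repepo.steering.utils.helpers import SteeringConfig

# Mirrors repepo/experiments/configs/sycophancy.yaml field-for-field.
config = SteeringConfig(
    use_base_model=False,
    model_size="7b",
    train_dataset_name="sycophancy_train",
    train_split_name="train-dev",
    formatter="llama-chat-formatter",
    aggregator="mean",
    layers=[0, 15, 31],
    multipliers=[-1, -0.5, 0, 0.5, 1],
    test_dataset_name="sycophancy_train",
    test_split_name="val-dev",
    test_completion_template="{prompt} My answer is: {response}",
    verbose=True,
)

run_experiment(config)
```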
27 changes: 27 additions & 0 deletions repepo/experiments/run_sweep.py
@@ -0,0 +1,27 @@
import matplotlib.pyplot as plt
from repepo.steering.run_experiment import run_experiment
from repepo.steering.utils.helpers import SteeringConfig, load_results
from repepo.steering.plot_results_by_layer import plot_results_by_layer


def run_sweep(configs: list[SteeringConfig], suffix=""):
    for config in configs:
        run_experiment(config)

    # load results
    all_results = []
    for config in configs:
        results = load_results(config)
        all_results.append((config, results))

    # plot results
    fig, axs = plt.subplots(len(all_results), 1, figsize=(10, 20))
    for i, (config, results) in enumerate(all_results):
        ax = axs[i]
        plot_results_by_layer(ax, config, results)
        ax.set_title(f"{config.train_dataset_name}")

    fig.tight_layout()
    if suffix:
        suffix = f"_{suffix}"
    fig.savefig(f"results{suffix}.png")
3 changes: 3 additions & 0 deletions repepo/experiments/sweeps/persona_and_token.sh
@@ -0,0 +1,3 @@
#!/bin/bash
python repepo/experiments/sweeps/run_sweep_persona_concepts.py
python repepo/experiments/sweeps/run_sweep_token_concepts.py
35 changes: 35 additions & 0 deletions repepo/experiments/sweeps/run_sweep_persona_concepts.py
@@ -0,0 +1,35 @@
from repepo.experiments.run_sweep import run_sweep
from repepo.steering.utils.helpers import SteeringConfig


def list_configs():
    datasets = [
        "machiavellianism",
        "desire-for-recursive-self-improvement",
        "sycophancy_train",
        "willingness-to-be-non-HHH-to-be-deployed-in-the-real-world",
        # "truthful_qa",
        "believes-abortion-should-be-illegal",
    ]

    return [
        SteeringConfig(
            use_base_model=False,
            model_size="7b",
            train_dataset_name=dataset_name,
            train_split_name="train-dev",
            formatter="llama-chat-formatter",
            aggregator="mean",
            verbose=True,
            layers=[0, 11, 12, 13, 14, 15, 31],
            multipliers=[-1, -0.5, 0, 0.5, 1],
            test_dataset_name=dataset_name,
            test_split_name="val-dev",
            test_completion_template="{prompt} My answer is: {response}",
        )
        for dataset_name in datasets
    ]


if __name__ == "__main__":
    run_sweep(list_configs(), "persona_concepts")
34 changes: 34 additions & 0 deletions repepo/experiments/sweeps/run_sweep_token_concepts.py
@@ -0,0 +1,34 @@
from repepo.experiments.run_sweep import run_sweep
from repepo.steering.utils.helpers import SteeringConfig


def list_configs():
    datasets = [
        "D02 [un+adj_reg]",
        "D07 [verb+able_reg]",
        "E01 [country - capital]",
        "E06 [animal - young]",
        "I01 [noun - plural_reg]",
        "I07 [verb_inf - Ved]",
    ]

    return [
        SteeringConfig(
            use_base_model=False,
            model_size="7b",
            train_dataset_name=dataset_name,
            train_split_name="train-dev",
            formatter="identity-formatter",
            aggregator="mean",
            verbose=True,
            layers=[0, 11, 12, 13, 14, 15, 31],
            multipliers=[-1, -0.5, 0, 0.5, 1],
            test_dataset_name=dataset_name,
            test_split_name="val-dev",
        )
        for dataset_name in datasets
    ]


if __name__ == "__main__":
    run_sweep(list_configs(), "token_concepts")
41 changes: 41 additions & 0 deletions repepo/steering/build_steering_training_data.py
@@ -0,0 +1,41 @@
import logging
from repepo.core.types import Dataset
from repepo.core.pipeline import Pipeline
from steering_vectors import SteeringVectorTrainingSample


def _validate_train_dataset(dataset: Dataset):
    steering_token_index = dataset[0].steering_token_index
    for example in dataset:
        assert example.steering_token_index == steering_token_index


def build_steering_vector_training_data(
    pipeline: Pipeline,
    dataset: Dataset,
    logger: logging.Logger | None = None,
) -> list[SteeringVectorTrainingSample]:
    # Validate that all examples have the same steering token index
    _validate_train_dataset(dataset)
    # After validation, we can assume that all examples have the same steering token index
    read_token_index = dataset[0].steering_token_index

    # NOTE(dtch1997): Using SteeringVectorTrainingSample here
    # to encode information about token index
    steering_vector_training_data = [
        SteeringVectorTrainingSample(
            positive_str=pipeline.build_full_prompt(example.positive),
            negative_str=pipeline.build_full_prompt(example.negative),
            read_positive_token_index=read_token_index,
            read_negative_token_index=read_token_index,
        )
        for example in dataset
    ]

    if logger is not None:
        # Log first example
        datum = steering_vector_training_data[0]
        logger.info(f"Positive example: \n {datum.positive_str}")
        logger.info(f"Negative example: \n {datum.negative_str}")

    return steering_vector_training_data
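
Downstream, this list is what steering-vector training consumes. A sketch of the hand-off (assuming the `train_steering_vector` entry point from the `steering-vectors` library, and that the pipeline exposes its `model` and `tokenizer`; exact keyword arguments may vary between library versions):

```python
from steering_vectors import train_steering_vector

# Hypothetical hand-off; attribute names on `pipeline` are assumptions.
training_data = build_steering_vector_training_data(pipeline, train_dataset)
steering_vector = train_steering_vector(
    pipeline.model,
    pipeline.tokenizer,
    training_data,
    layers=[0, 15, 31],  # e.g. the layers swept in the configs above
)
```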
