-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add experimental code * fix: make_country_capital script * feat: add code to run steering experiment * update experiments code * fix: add --config_path arg * fix: config yaml parsing * chore: add more configs * chore: add even more configs * refactor: plotting * feat: add script to run sweep * fix: do not set completion template by default * refactor sweeps * refactor: token concept sweep * fix: bugbears * chore: add comments * fix: steering_index of datasets * test: steering token index * updating steering_vectors library version * evaluate on more layers * refactor: use steering-vectors code, log instead of print * chore: fix docstring * test: training, evaluating steering vectors * fix: minor --------- Co-authored-by: Daniel CH Tan <dtch1997@users.noreply.github.com> Co-authored-by: David Chanin <chanindav@gmail.com>
- Loading branch information
1 parent
986e9b2
commit 8d1bd7d
Showing
30 changed files
with
1,011 additions
and
14 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
33 changes: 33 additions & 0 deletions
33
repepo/data/multiple_choice/make_country_capital_with_prompt.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from repepo.core.types import Dataset | ||
from repepo.data.multiple_choice.make_bats import ( | ||
convert_bats_dataset, | ||
get_raw_dataset_dir, | ||
build_dataset_filename, | ||
get_dataset_dir, | ||
jdump, | ||
) | ||
|
||
PROMPT = "I live in" | ||
|
||
|
||
def make_country_capital_with_prompt(): | ||
dataset_path = get_raw_dataset_dir() / ( | ||
"bats/3_Encyclopedic_semantics/E01 [country - capital].txt" | ||
) | ||
assert dataset_path.exists() | ||
|
||
with open(dataset_path, "r") as file: | ||
list_dataset = [] | ||
for line in file: | ||
element = line.strip().split() | ||
assert len(element) == 2 | ||
list_dataset.append(tuple(element)) | ||
|
||
dataset_name = "country-capital-with-prompt" | ||
filename = build_dataset_filename(dataset_name) | ||
mwe_dataset: Dataset = convert_bats_dataset(list_dataset, prompt=PROMPT) | ||
jdump(mwe_dataset, get_dataset_dir() / "bats" / filename) | ||
|
||
|
||
if __name__ == "__main__": | ||
make_country_capital_with_prompt() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
15 changes: 15 additions & 0 deletions
15
repepo/experiments/configs/believes-abortion-should-be-illegal.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
|
||
use_base_model: False | ||
model_size: 7b | ||
train_dataset_name: believes-abortion-should-be-illegal | ||
train_split_name: train-dev | ||
|
||
formatter: llama-chat-formatter | ||
aggregator: mean | ||
layers: [0, 15, 31] | ||
multipliers: [-1, -0.5, 0, 0.5, 1] | ||
|
||
test_dataset_name: believes-abortion-should-be-illegal | ||
test_split_name: val-dev | ||
test_completion_template: "{prompt} My answer is: {response}" | ||
verbose: True |
15 changes: 15 additions & 0 deletions
15
repepo/experiments/configs/desire-for-recursive-self-improvement.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
|
||
use_base_model: False | ||
model_size: 7b | ||
train_dataset_name: desire-for-recursive-self-improvement | ||
train_split_name: train-dev | ||
|
||
formatter: llama-chat-formatter | ||
aggregator: mean | ||
layers: [0, 15, 31] | ||
multipliers: [-1, -0.5, 0, 0.5, 1] | ||
|
||
test_dataset_name: desire-for-recursive-self-improvement | ||
test_split_name: val-dev | ||
test_completion_template: "{prompt} My answer is: {response}" | ||
verbose: True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
|
||
use_base_model: False | ||
model_size: 7b | ||
train_dataset_name: machiavellianism | ||
train_split_name: train-dev | ||
|
||
formatter: llama-chat-formatter | ||
aggregator: mean | ||
layers: [0, 15, 31] | ||
multipliers: [-1, -0.5, 0, 0.5, 1] | ||
|
||
test_dataset_name: machiavellianism | ||
test_split_name: val-dev | ||
test_completion_template: "{prompt} My answer is: {response}" | ||
verbose: True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
|
||
use_base_model: False | ||
model_size: 7b | ||
train_dataset_name: sycophancy_train | ||
train_split_name: train-dev | ||
|
||
formatter: llama-chat-formatter | ||
aggregator: mean | ||
layers: [0, 15, 31] | ||
multipliers: [-1, -0.5, 0, 0.5, 1] | ||
|
||
test_dataset_name: sycophancy_train | ||
test_split_name: val-dev | ||
test_completion_template: "{prompt} My answer is: {response}" | ||
verbose: True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
|
||
use_base_model: False | ||
model_size: 7b | ||
train_dataset_name: truthfulqa | ||
train_split_name: train-dev | ||
|
||
formatter: llama-chat-formatter | ||
aggregator: mean | ||
layers: [0, 15, 31] | ||
multipliers: [-1, -0.5, 0, 0.5, 1] | ||
|
||
test_dataset_name: truthfulqa | ||
test_split_name: val-dev | ||
test_completion_template: "{prompt} My answer is: {response}" | ||
verbose: True |
15 changes: 15 additions & 0 deletions
15
repepo/experiments/configs/willingness-to-be-non-HHH-to-be-deployed-in-the-real-world.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
|
||
use_base_model: False | ||
model_size: 7b | ||
train_dataset_name: willingness-to-be-non-HHH-to-be-deployed-in-the-real-world | ||
train_split_name: train-dev | ||
|
||
formatter: llama-chat-formatter | ||
aggregator: mean | ||
layers: [0, 15, 31] | ||
multipliers: [-1, -0.5, 0, 0.5, 1] | ||
|
||
test_dataset_name: willingness-to-be-non-HHH-to-be-deployed-in-the-real-world | ||
test_split_name: val-dev | ||
test_completion_template: "{prompt} My answer is: {response}" | ||
verbose: True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import matplotlib.pyplot as plt | ||
from repepo.steering.run_experiment import run_experiment | ||
from repepo.steering.utils.helpers import SteeringConfig, load_results | ||
from repepo.steering.plot_results_by_layer import plot_results_by_layer | ||
|
||
|
||
def run_sweep(configs: list[SteeringConfig], suffix=""): | ||
for config in configs: | ||
run_experiment(config) | ||
|
||
# load results | ||
all_results = [] | ||
for config in configs: | ||
results = load_results(config) | ||
all_results.append((config, results)) | ||
|
||
# plot results | ||
fig, axs = plt.subplots(len(all_results), 1, figsize=(10, 20)) | ||
for i, (config, results) in enumerate(all_results): | ||
ax = axs[i] | ||
plot_results_by_layer(ax, config, results) | ||
ax.set_title(f"{config.train_dataset_name}") | ||
|
||
fig.tight_layout() | ||
if suffix: | ||
suffix = f"_{suffix}" | ||
fig.savefig(f"results{suffix}.png") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#!/bin/bash | ||
python repepo/experiments/sweeps/run_sweep_persona_concepts.py | ||
python repepo/experiments/sweeps/run_sweep_token_concepts.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from repepo.experiments.run_sweep import run_sweep | ||
from repepo.steering.utils.helpers import SteeringConfig | ||
|
||
|
||
def list_configs(): | ||
datasets = [ | ||
"machiavellianism", | ||
"desire-for-recursive-self-improvement", | ||
"sycophancy_train", | ||
"willingness-to-be-non-HHH-to-be-deployed-in-the-real-world", | ||
# "truthful_qa", | ||
"believes-abortion-should-be-illegal", | ||
] | ||
|
||
return [ | ||
SteeringConfig( | ||
use_base_model=False, | ||
model_size="7b", | ||
train_dataset_name=dataset_name, | ||
train_split_name="train-dev", | ||
formatter="llama-chat-formatter", | ||
aggregator="mean", | ||
verbose=True, | ||
layers=[0, 11, 12, 13, 14, 15, 31], | ||
multipliers=[-1, -0.5, 0, 0.5, 1], | ||
test_dataset_name=dataset_name, | ||
test_split_name="val-dev", | ||
test_completion_template="{prompt} My answer is: {response}", | ||
) | ||
for dataset_name in datasets | ||
] | ||
|
||
|
||
if __name__ == "__main__": | ||
run_sweep(list_configs(), "persona_concepts") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from repepo.experiments.run_sweep import run_sweep | ||
from repepo.steering.utils.helpers import SteeringConfig | ||
|
||
|
||
def list_configs(): | ||
datasets = [ | ||
"D02 [un+adj_reg]", | ||
"D07 [verb+able_reg]", | ||
"E01 [country - capital]", | ||
"E06 [animal - young]", | ||
"I01 [noun - plural_reg]", | ||
"I07 [verb_inf - Ved]", | ||
] | ||
|
||
return [ | ||
SteeringConfig( | ||
use_base_model=False, | ||
model_size="7b", | ||
train_dataset_name=dataset_name, | ||
train_split_name="train-dev", | ||
formatter="identity-formatter", | ||
aggregator="mean", | ||
verbose=True, | ||
layers=[0, 11, 12, 13, 14, 15, 31], | ||
multipliers=[-1, -0.5, 0, 0.5, 1], | ||
test_dataset_name=dataset_name, | ||
test_split_name="val-dev", | ||
) | ||
for dataset_name in datasets | ||
] | ||
|
||
|
||
if __name__ == "__main__": | ||
run_sweep(list_configs(), "token_concepts") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import logging | ||
from repepo.core.types import Dataset | ||
from repepo.core.pipeline import Pipeline | ||
from steering_vectors import SteeringVectorTrainingSample | ||
|
||
|
||
def _validate_train_dataset(dataset: Dataset): | ||
steering_token_index = dataset[0].steering_token_index | ||
for example in dataset: | ||
assert example.steering_token_index == steering_token_index | ||
|
||
|
||
def build_steering_vector_training_data( | ||
pipeline: Pipeline, | ||
dataset: Dataset, | ||
logger: logging.Logger | None = None, | ||
) -> list[SteeringVectorTrainingSample]: | ||
# Validate that all examples have the same steering token index | ||
_validate_train_dataset(dataset) | ||
# After validation, we can assume that all examples have the same steering token index | ||
read_token_index = dataset[0].steering_token_index | ||
|
||
# NOTE(dtch1997): Using SteeringVectorTrainingSample here | ||
# to encode information about token index | ||
steering_vector_training_data = [ | ||
SteeringVectorTrainingSample( | ||
positive_str=pipeline.build_full_prompt(example.positive), | ||
negative_str=pipeline.build_full_prompt(example.negative), | ||
read_positive_token_index=read_token_index, | ||
read_negative_token_index=read_token_index, | ||
) | ||
for example in dataset | ||
] | ||
|
||
if logger is not None: | ||
# Log first example | ||
datum = steering_vector_training_data[0] | ||
logger.info(f"Positive example: \n {datum.positive_str}") | ||
logger.info(f"Negative example: \n {datum.negative_str}") | ||
|
||
return steering_vector_training_data |
Oops, something went wrong.