Generalization experiments #96

Merged · 58 commits · Feb 6, 2024

Changes from 50 commits

Commits
d69077f
Add functions to do translation
dtch1997 Jan 29, 2024
37c5e39
Add TQA translate
dtch1997 Jan 30, 2024
6f694d5
Fix key name bug
dtch1997 Jan 31, 2024
774d712
WIP
dtch1997 Jan 31, 2024
79d660c
Merge branch 'main' into tqa_translate
dtch1997 Jan 31, 2024
8d1475e
Add script to generate TQA translated datasets
dtch1997 Jan 31, 2024
9e19eb6
update expt name and dataset splits
dtch1997 Jan 31, 2024
36f2914
Add Llama chat formatter
dtch1997 Jan 31, 2024
9548a51
Minor fixes in caa_repro
dtch1997 Jan 31, 2024
4891c1e
Add options to print output, save steering vectors
dtch1997 Jan 31, 2024
01594ca
Set default experiment path by train / test datasets
dtch1997 Jan 31, 2024
56b5201
Add functionality to print examples
dtch1997 Jan 31, 2024
2181328
Add script to plot results
dtch1997 Jan 31, 2024
dce925a
Add title to plotting code
dtch1997 Jan 31, 2024
0a70e84
Fix pdm lock
dtch1997 Feb 2, 2024
11306b7
Add (very ugly) function to plot multiple results
dtch1997 Feb 2, 2024
88104e0
Ignore png files
dtch1997 Feb 2, 2024
123407d
Enable translated system prompt
dtch1997 Feb 2, 2024
f59b090
Add new experiments dir
dtch1997 Feb 4, 2024
f198ab0
Add notebook to analyze TQA vectors
dtch1997 Feb 4, 2024
e67c49f
Add script to download datasets
dtch1997 Feb 4, 2024
d54bdfd
Add script to download datasets
dtch1997 Feb 4, 2024
8472eca
WIP translate
dtch1997 Feb 4, 2024
860c01b
Add code to extract and save steering vectors
dtch1997 Feb 4, 2024
f1b0f9e
Update experiments
dtch1997 Feb 5, 2024
ab47ac6
Add more dataset names
dtch1997 Feb 5, 2024
55e2ff7
Improve dataset inspection
dtch1997 Feb 5, 2024
434924f
Modify script to extract all SVs
dtch1997 Feb 5, 2024
7e61c82
Changes to notebooks
dtch1997 Feb 5, 2024
f1deec4
Update readme
dtch1997 Feb 5, 2024
0ed4d28
WIP
dtch1997 Feb 5, 2024
2788fe1
Fix download datasets
dtch1997 Feb 5, 2024
697397d
Enable 4-bit loading
dtch1997 Feb 5, 2024
e3ccc53
WIP
dtch1997 Feb 5, 2024
785c194
Visualize pairwise cos similarities
dtch1997 Feb 5, 2024
64a6b02
Inspect dataset's dataframe
dtch1997 Feb 6, 2024
0caa0eb
Clustering results
dtch1997 Feb 6, 2024
d082071
Fix lint errors
dtch1997 Feb 6, 2024
ef1327e
Add script to extract concept vectors
dtch1997 Feb 6, 2024
a7f4b12
WIP
dtch1997 Feb 6, 2024
2cc603b
Refactoring
dtch1997 Feb 6, 2024
fd26ff7
Refactoring
dtch1997 Feb 6, 2024
72ac338
Add script to run all experiments
dtch1997 Feb 6, 2024
69657f1
Fix bug with results suffix
dtch1997 Feb 6, 2024
dc4f791
Uncomment some lines
dtch1997 Feb 6, 2024
9f0c645
Update README, bash script
dtch1997 Feb 6, 2024
24050dc
Restore original experiments dir
dtch1997 Feb 6, 2024
9555251
Merge branch 'main' into generalization_experiments
dtch1997 Feb 6, 2024
a9204b7
Fix lint
dtch1997 Feb 6, 2024
75844a9
Fix lint
dtch1997 Feb 6, 2024
373f501
Merge branch 'main' into generalization_experiments
chanind Feb 6, 2024
dfc1fed
Add more aggregations
dtch1997 Feb 6, 2024
6c2b08b
Fix bug in download
dtch1997 Feb 6, 2024
9e1320f
Ignore html files
dtch1997 Feb 6, 2024
5310167
Add test for data preprocessing
dtch1997 Feb 6, 2024
a27667a
Add tests for preprocessing
dtch1997 Feb 6, 2024
5c25afe
fixing black formatting issues
chanind Feb 6, 2024
4fde15c
fixing typing
chanind Feb 6, 2024
2,858 changes: 2,046 additions & 812 deletions pdm.lock

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions pyproject.toml
@@ -26,6 +26,13 @@ dependencies = [
"nbdime>=4.0.1",
"steering-vectors>=0.3.0",
"openai>=1.10.0",
"arrr>=1.0.4",
"spacy>=3.7.2",
"mosestokenizer>=1.2.1",
"gradio>=4.16.0",
"simple-parsing>=0.1.5",
"torchmetrics>=1.3.0.post0",
"umap-learn>=0.5.5",
]

[tool.black]
25 changes: 22 additions & 3 deletions repepo/algorithms/repe.py
@@ -3,6 +3,7 @@
from typing import Literal, Optional
from typing_extensions import override
import random
import torch
from repepo.core.pipeline import PipelineContext

from steering_vectors import (
@@ -97,6 +98,7 @@ class RepeReadingControl(Algorithm):
read_token_index: int
seed: int
show_progress: bool
verbose: bool

def __init__(
self,
@@ -117,6 +119,8 @@ def __init__(
# Reference: https://github.com/nrimsky/SycophancySteering/blob/25f93a1f1aad51f94288f52d01f6a10d10f42bf1/generate_vectors.py#L102C13-L102C67
read_token_index: int = -1,
show_progress: bool = True,
verbose: bool = False,
steering_vector_save_path: Optional[str] = None,
):
self.multi_answer_method = multi_answer_method
self.layer_type = layer_type
@@ -130,6 +134,8 @@
self.layer_config = layer_config
self.direction_multiplier = direction_multiplier
self.show_progress = show_progress
self.verbose = verbose
self.steering_vector_save_path = steering_vector_save_path

self.skip_reading = skip_reading
self.override_vector = override_vector
@@ -164,6 +170,17 @@ def _get_steering_vector(
repe_training_data = self._build_steering_vector_training_data(
dataset, pipeline
)
if self.verbose:
# Print a small section of the dataset
pos_example, neg_example = repe_training_data[0]
print("Example steering vector training data:")
print("Positive prompt:")
print(pos_example)
print("Negative prompt:")
print(neg_example)
for i in range(2):
print()

return train_steering_vector(
pipeline.model,
pipeline.tokenizer,
@@ -178,9 +195,7 @@ def _get_steering_vector(

@override
def run(self, pipeline: Pipeline, dataset: Dataset) -> Pipeline:
# Steering vector reading
# NOTE: The hooks read from this steering vector.

# TODO: Clean up this horrible code...
if self.override_vector is not None:
steering_vector: SteeringVector = self.override_vector
elif not self.skip_reading:
@@ -192,6 +207,10 @@ def run(self, pipeline: Pipeline, dataset: Dataset) -> Pipeline:
"Either reading or override vector must be provided for control"
)

if self.steering_vector_save_path is not None:
# TODO: Refactor into steering_vector.save
torch.save(steering_vector, self.steering_vector_save_path)

# Creating the hooks that will do steering vector control
# NOTE: How this works is that we create a context manager that creates a hook
# whenever we are in a `PipelineContext`'s scope.
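Reviewer note: a minimal usage sketch of the two options this file adds to `RepeReadingControl`. Only `verbose`, `steering_vector_save_path`, `direction_multiplier`, and `run()` appear in this diff; the pipeline, dataset, and save path below are placeholders.

```python
from repepo.algorithms.repe import RepeReadingControl

# Sketch (assumptions flagged above): exercising the new constructor options.
# `verbose=True` prints one positive/negative training pair before training;
# `steering_vector_save_path` persists the trained vector via torch.save.
algorithm = RepeReadingControl(
    direction_multiplier=1.0,
    verbose=True,
    steering_vector_save_path="steering_vector.pt",  # placeholder path
)
# `pipeline` and `dataset` are assumed to be constructed elsewhere
# (model, tokenizer, and a contrastive-pair Dataset):
# steered_pipeline = algorithm.run(pipeline, dataset)
```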
2 changes: 2 additions & 0 deletions repepo/core/benchmark.py
@@ -54,6 +54,7 @@ def evaluate_benchmark(
eval_hooks: list[EvalHook] = [],
show_progress: bool = True,
tqdm_desc: str = "Evaluating",
verbose: bool = False,
) -> EvalResult:
# evaluate
return evaluate(
@@ -64,6 +65,7 @@
eval_hooks=eval_hooks,
show_progress=show_progress,
tqdm_desc=tqdm_desc,
verbose=verbose,
)


27 changes: 25 additions & 2 deletions repepo/core/evaluate.py
@@ -19,6 +19,20 @@
EvalHook = Callable[[Pipeline], AbstractContextManager[None]]


def print_first_example() -> EvalHook:
"""Eval hook that prints the first example"""

@contextmanager
def print_first_example_hook(pipeline: Pipeline):
try:
pipeline.print_first_example = True
yield
finally:
pipeline.print_first_example = False

return print_first_example_hook


def update_completion_template_at_eval(new_template: str) -> EvalHook:
"""Eval hook that changes the completion template for the duration of the evaluation"""

@@ -200,7 +214,7 @@ def score_prediction(self, prediction: EvalPrediction) -> float:

def __call__(self, predictions: Sequence[EvalPrediction]) -> dict[str, float]:
pred_results = [self.score_prediction(pred) for pred in predictions]
return {"accuracy": mean(pred_results)}
return {"average_key_prob": mean(pred_results)}


def evaluate(
@@ -213,6 +227,7 @@
eval_hooks: Sequence[EvalHook] = [],
show_progress: bool = True,
tqdm_desc: str = "Evaluating",
verbose: bool = False,
) -> EvalResult:
# evaluate
predictions: list[EvalPrediction] = []
@@ -222,15 +237,23 @@
for eval_hook in eval_hooks:
stack.enter_context(eval_hook(pipeline))
# TODO: support batching
for example in tqdm(dataset, disable=not show_progress, desc=tqdm_desc):
for i, example in enumerate(
tqdm(dataset, disable=not show_progress, desc=tqdm_desc)
):
generated_output = None
correct_output_probs = None
incorrect_outputs_probs = None
if requires_generation:
if i == 0 and verbose:
print("Example generation prompt:")
print(pipeline.build_generation_prompt(example))
generated_output = pipeline.generate(
example, generation_config=generation_config
)
if requires_probs:
if i == 0 and verbose:
print("Example full prompt:")
print(pipeline.build_full_prompt(example))
correct_output_probs = pipeline.calculate_output_logprobs(example)
if example.incorrect_outputs is not None:
incorrect_outputs_probs = [
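Reviewer note: the new `print_first_example` hook follows the same shape as the existing eval hooks: a factory that returns a context manager which flips pipeline state for the duration of `evaluate()` and resets it in `finally`. A self-contained sketch of the pattern, using a toy class as a stand-in for repepo's real `Pipeline`:

```python
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass

@dataclass
class ToyPipeline:
    # Stand-in for repepo's Pipeline; only the flag matters for this sketch.
    print_first_example: bool = False

def print_first_example_sketch():
    # Matches the EvalHook shape: Callable[[Pipeline], AbstractContextManager[None]].
    @contextmanager
    def hook(pipeline: ToyPipeline):
        pipeline.print_first_example = True
        try:
            yield
        finally:
            pipeline.print_first_example = False
    return hook

pipeline = ToyPipeline()
with ExitStack() as stack:  # evaluate() enters every hook this same way
    stack.enter_context(print_first_example_sketch()(pipeline))
    assert pipeline.print_first_example  # active while evaluating
assert not pipeline.print_first_example  # cleared after the eval scope exits
```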
1 change: 1 addition & 0 deletions repepo/core/pipeline.py
@@ -50,6 +50,7 @@ class Pipeline:
formatter: Formatter = field(default_factory=InputOutputFormatter)
conversation_history: list[Example] = field(default_factory=list)
hooks: list[PipelineHook] = field(default_factory=list)
print_first_example: bool = True

def build_generation_prompt(self, example: Example) -> str:
"""Build a prompt for generation"""
21 changes: 13 additions & 8 deletions repepo/data/make_dataset.py
@@ -23,15 +23,17 @@ def get_all_json_filepaths(root_dir: pathlib.Path) -> List[pathlib.Path]:


# Intentionally don't cache anything here, otherwise datasets won't be available after downloading
def _get_datasets() -> dict[str, pathlib.Path]:
def _get_datasets(dataset_dir: pathlib.Path | None = None) -> dict[str, pathlib.Path]:
datasets: dict[str, pathlib.Path] = {}
for path in get_all_json_filepaths(get_dataset_dir()):
if dataset_dir is None:
dataset_dir = get_dataset_dir()
for path in get_all_json_filepaths(dataset_dir):
datasets[path.stem] = path.absolute()
return datasets


def list_datasets() -> tuple[str, ...]:
return tuple(_get_datasets().keys())
def list_datasets(dataset_dir: pathlib.Path | None = None) -> tuple[str, ...]:
return tuple(_get_datasets(dataset_dir).keys())


@dataclass
@@ -40,6 +42,9 @@ class DatasetSpec:
split: str = ":100%"
seed: int = 0

def __repr__(self) -> str:
return f"DatasetSpec(name={self.name},split={self.split},seed={self.seed})"


def _parse_split(split_string: str, length: int) -> slice:
# Define the regular expression pattern
@@ -61,8 +66,8 @@
raise ValueError(f"Parse string {split_string} not recognized")


def get_dataset(name: str) -> Dataset:
datasets = _get_datasets()
def get_dataset(name: str, dataset_dir: pathlib.Path | None = None) -> Dataset:
datasets = _get_datasets(dataset_dir)
if name not in datasets:
raise ValueError(f"Unknown dataset: {name}")

@@ -84,6 +89,6 @@ def _shuffle_and_split(items: list[T], split_string: str, seed: int) -> list[T]:
return shuffled_items[split]


def make_dataset(spec: DatasetSpec):
dataset = get_dataset(spec.name)
def make_dataset(spec: DatasetSpec, dataset_dir: pathlib.Path | None = None):
dataset = get_dataset(spec.name, dataset_dir)
return _shuffle_and_split(dataset, spec.split, spec.seed)
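Reviewer note: the new `dataset_dir` parameter threads through `_get_datasets` to `list_datasets`, `get_dataset`, and `make_dataset`, so callers can point at a non-default dataset directory (e.g. the translated datasets added elsewhere in this PR). Separately, for anyone reading `_parse_split` out of context, here is a self-contained sketch of the percentage split syntax that `DatasetSpec.split` (default `":100%"`) appears to use; the grammar below is an assumption covering only `a%:b%` forms, and the real regex may accept more.

```python
import re

def parse_percent_split(split_string: str, length: int) -> slice:
    # Hypothetical re-implementation of _parse_split's "a%:b%" grammar;
    # either bound may be omitted, e.g. ":80%" or "80%:".
    match = re.match(r"^(?:(\d+)%)?:(?:(\d+)%)?$", split_string)
    if match is None:
        raise ValueError(f"Parse string {split_string} not recognized")
    start_pct = int(match.group(1) or 0)
    end_pct = int(match.group(2) or 100)
    return slice(start_pct * length // 100, end_pct * length // 100)

items = list(range(10))
print(items[parse_percent_split(":80%", len(items))])  # first 80%: [0, ..., 7]
print(items[parse_percent_split("80%:", len(items))])  # last 20%: [8, 9]
```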