Skip to content

Commit

Permalink
interleaving positive and negative prompts to match what the original…
Browse files Browse the repository at this point in the history
… repe code expects (#59)
  • Loading branch information
chanind authored Jan 11, 2024
1 parent 856145e commit 8eb3b75
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 58 deletions.
84 changes: 61 additions & 23 deletions repepo/algorithms/repe.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from dataclasses import replace
from typing import Any, Literal, NamedTuple, Optional, cast
from typing_extensions import override
import numpy as np
import random

import torch
from repepo.core import Pipeline
from repepo.core.format import Formatter
from repepo.core.types import Dataset, Model
from repepo.core.types import Dataset, Example, Model
from repepo.repe.rep_reading_pipeline import RepReadingPipeline
from repepo.algorithms.base import Algorithm
from repepo.utils.layer_matching import (
Expand All @@ -18,6 +18,7 @@
from repepo.utils.model_patcher import LayerType, ModelPatcher

DirectionMethod = Literal["pca", "cluster_mean", "random"]
MultiAnswerMethod = Literal["first_incorrect", "random_incorrect", "repeat_correct"]

DEFAULT_READING_TEMPLATE = "Consider the amount of truthfulness in the following answer.\n\n{question} {answer}\n\nThe amount of truthfulness in the answer is "

Expand All @@ -42,36 +43,41 @@ class RepeTrainingData(NamedTuple):

class RepeDirections(NamedTuple):
activations: dict[int, torch.Tensor]
# why is this a dict of np arrays instead of a dict of torch tensors? mystery
signs: dict[int, np.ndarray]
signs: dict[int, int]


class RepeReadingControl(Algorithm):
direction_method: DirectionMethod
layer_type: LayerType
multi_answer_method: MultiAnswerMethod
reading_template: str
layers: list[int] | None
n_difference: int
batch_size: int
max_length: int
layer_config: ModelLayerConfig | None
direction_finder_kwargs: dict[str, Any]
seed: int

def __init__(
self,
reading_template: str = DEFAULT_READING_TEMPLATE,
direction_method: DirectionMethod = "pca",
layer_type: LayerType = "decoder_block",
multi_answer_method: MultiAnswerMethod = "first_incorrect",
n_difference: int = 1, # TODO: what does this do?
batch_size: int = 8,
max_length: int = 2048,
seed: int = 0,
layers: Optional[list[int]] = None,
layer_config: Optional[ModelLayerConfig] = None,
# TODO: remove this when refactoring repe reading pipeline
direction_finder_kwargs: Optional[dict[str, Any]] = None,
):
self.direction_method = direction_method
self.multi_answer_method = multi_answer_method
self.layer_type = layer_type
self.seed = seed
_validate_reading_template(reading_template)
self.reading_template = reading_template
self.layers = layers
Expand All @@ -86,24 +92,16 @@ def _build_repe_training_data_and_labels(
) -> RepeTrainingData:
prompts: list[str] = []
grouped_labels: list[list[int]] = []
for example in dataset:
if example.incorrect_outputs is None:
raise ValueError(
"RepEngReadingControl requires incorrect_outputs to be set"
)
incorrect_examples = [
replace(example, output=output) for output in example.incorrect_outputs
]
label_group = [1, *([0] * len(incorrect_examples))]
grouped_labels.append(label_group)
group_examples = [example, *incorrect_examples]
for group_example in group_examples:
completion = formatter.apply(group_example)
prompts.append(
self.reading_template.format(
question=completion.prompt, answer=completion.response
)
)
for i, example in enumerate(dataset):
example_prompts, example_labels = self._convert_example_to_repe_format(
example,
formatter,
# repe doesn't reflect differences across the origin, so we need to ensure
# an even balance of reversed and non-reversed examples
reverse_order=i % 2 == 0,
)
prompts.extend(example_prompts)
grouped_labels.append(example_labels)
return RepeTrainingData(prompts=prompts, labels=grouped_labels)

def _get_layer_matcher_for_model(self, model: Model) -> LayerMatcher:
Expand Down Expand Up @@ -154,7 +152,7 @@ def _get_directions(self, pipeline: Pipeline, dataset: Dataset) -> RepeDirection
key: torch.FloatTensor(val)
for key, val in rep_reader.directions.items()
},
signs=rep_reader.direction_signs,
signs={key: val.item() for key, val in rep_reader.direction_signs.items()},
)

@override
Expand All @@ -167,3 +165,43 @@ def run(self, pipeline: Pipeline, dataset: Dataset) -> Pipeline:
layer_type=self.layer_type,
)
return pipeline

def _convert_example_to_repe_format(
    self, example: Example, formatter: Formatter, reverse_order: bool
) -> tuple[list[str], list[int]]:
    """Convert one example into the interleaved prompts and labels used by repe.

    The correct output is paired with one or more incorrect outputs, chosen
    according to ``self.multi_answer_method``:

    - ``"first_incorrect"``: pair with the first incorrect output only.
    - ``"random_incorrect"``: pair with one incorrect output picked by a
      generator seeded from ``self.seed`` and ``example.input``.
    - ``"repeat_correct"``: pair every incorrect output with a repeated copy
      of the correct example.

    Args:
        example: The example to convert; ``incorrect_outputs`` must be set.
        formatter: Applied to every example; the ``prompt`` and ``response``
            of its result fill ``self.reading_template``.
        reverse_order: When true each pair is emitted as (incorrect, correct)
            with labels ``[0, 1]``; otherwise (correct, incorrect) with
            labels ``[1, 0]``.

    Returns:
        A ``(prompts, labels)`` tuple of equal length, where label ``1``
        marks the correct answer and ``0`` an incorrect one.

    Raises:
        ValueError: If ``incorrect_outputs`` is ``None``, if it is empty while
            a single incorrect output has to be selected, or if
            ``self.multi_answer_method`` is not a known method.
    """
    if example.incorrect_outputs is None:
        raise ValueError(
            "RepEngReadingControl requires incorrect_outputs to be set"
        )
    incorrect_examples = [
        replace(example, output=output) for output in example.incorrect_outputs
    ]
    correct_examples = [example]
    if self.multi_answer_method == "repeat_correct":
        correct_examples = [example] * len(incorrect_examples)
    elif self.multi_answer_method in ("first_incorrect", "random_incorrect"):
        # Selecting a single negative from an empty list would otherwise fail
        # with a bare IndexError; report the actual problem instead.
        if not incorrect_examples:
            raise ValueError(
                "RepEngReadingControl requires at least one incorrect output "
                f"when multi_answer_method is {self.multi_answer_method!r}"
            )
        if self.multi_answer_method == "first_incorrect":
            incorrect_examples = [incorrect_examples[0]]
        else:
            # Seeding from the seed and the example input keeps the choice
            # reproducible for a given example, independent of dataset order.
            rand_gen = random.Random(f"{self.seed}-{example.input}")
            incorrect_examples = [rand_gen.choice(incorrect_examples)]
    else:
        raise ValueError(f"Unknown multi_answer_method {self.multi_answer_method}")
    # Internal invariant: every branch above yields equally long lists.
    assert len(incorrect_examples) == len(correct_examples)
    # repe (undocumentedly) expects interleaved positive and negative examples
    if reverse_order:
        labels = [0, 1] * len(correct_examples)
        paired_examples = zip(incorrect_examples, correct_examples)
    else:
        labels = [1, 0] * len(correct_examples)
        paired_examples = zip(correct_examples, incorrect_examples)
    # interleaved pos, neg, pos, neg, ... (or neg, pos, ... when reversed)
    examples = [ex for pair in paired_examples for ex in pair]
    completions = [formatter.apply(ex) for ex in examples]
    prompts = [
        self.reading_template.format(
            question=completion.prompt, answer=completion.response
        )
        for completion in completions
    ]
    return prompts, labels
28 changes: 2 additions & 26 deletions tests/algorithms/__snapshots__/test_repe.ambr
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
# serializer version: 1
# name: test_RepeReadingControl_build_repe_training_data_and_labels
# name: test_RepeReadingControl_build_repe_training_data_and_labels_picks_one_neg_by_default
list([
'''
Consider the amount of truthfulness in the following answer.

Input: Paris is in
Output: France

The amount of truthfulness in the answer is
''',
'''
Consider the amount of truthfulness in the following answer.

Expand All @@ -21,7 +13,7 @@
Consider the amount of truthfulness in the following answer.

Input: Paris is in
Output: Italy
Output: France

The amount of truthfulness in the answer is
''',
Expand All @@ -39,22 +31,6 @@
Input: 1 + 1 =
Output: 11

The amount of truthfulness in the answer is
''',
'''
Consider the amount of truthfulness in the following answer.

Input: 1 + 1 =
Output: 1234

The amount of truthfulness in the answer is
''',
'''
Consider the amount of truthfulness in the following answer.

Input: 1 + 1 =
Output: 3.14

The amount of truthfulness in the answer is
''',
])
Expand Down
104 changes: 95 additions & 9 deletions tests/algorithms/test_repe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from transformers import GPTNeoXForCausalLM


def test_RepeReadingControl_build_repe_training_data_and_labels(
def test_RepeReadingControl_build_repe_training_data_and_labels_picks_one_neg_by_default(
snapshot: SnapshotAssertion,
) -> None:
dataset: Dataset = [
Expand All @@ -20,7 +20,7 @@ def test_RepeReadingControl_build_repe_training_data_and_labels(
instruction="",
input="1 + 1 =",
output="2",
incorrect_outputs=["11", "1234", "3.14"],
incorrect_outputs=["11", "34", "3.14"],
),
]
formatter = InputOutputFormatter()
Expand All @@ -29,11 +29,87 @@ def test_RepeReadingControl_build_repe_training_data_and_labels(
dataset, formatter
)
# for some reason the training data isn't grouped, but labels are. This is how it is in the original code.
assert len(training_data) == 7
assert labels == [[1, 0, 0], [1, 0, 0, 0]]
assert len(training_data) == 4
# should pick the first incorrect output only by default
assert "Germany" in training_data[0]
assert "France" in training_data[1]
assert "2" in training_data[2]
assert "11" in training_data[3]
# should alternate between flipped and non-flipped labels
assert labels == [[0, 1], [1, 0]]
assert training_data == snapshot


def test_RepeReadingControl_build_repe_training_data_and_labels_with_random_incorrect() -> (
    None
):
    """Check that "random_incorrect" pairs each answer with one sampled negative."""
    dataset: Dataset = [
        Example(
            instruction="",
            input="Paris is in",
            output="France",
            incorrect_outputs=["Germany", "Italy"],
        ),
        Example(
            instruction="",
            input="1 + 1 =",
            output="2",
            incorrect_outputs=["11", "34", "3.14"],
        ),
    ]
    formatter = InputOutputFormatter()
    # The method must be requested explicitly; the default is "first_incorrect",
    # which would leave the random selection path untested.
    algorithm = RepeReadingControl(multi_answer_method="random_incorrect")
    training_data, labels = algorithm._build_repe_training_data_and_labels(
        dataset, formatter
    )
    # for some reason the training data isn't grouped, but labels are. This is how it is in the original code.
    assert len(training_data) == 4
    # should pick a random incorrect output for each example
    assert any(wrong in training_data[0] for wrong in ["Germany", "Italy"])
    assert "France" in training_data[1]
    assert "2" in training_data[2]
    assert any(wrong in training_data[3] for wrong in ["11", "34", "3.14"])
    # should alternate between flipped and non-flipped labels
    assert labels == [[0, 1], [1, 0]]


def test_RepeReadingControl_build_repe_training_data_and_labels_with_repeat_correct() -> (
    None
):
    """Check that "repeat_correct" repeats the positive once per incorrect output."""
    paris = Example(
        instruction="",
        input="Paris is in",
        output="France",
        incorrect_outputs=["Germany", "Italy"],
    )
    addition = Example(
        instruction="",
        input="1 + 1 =",
        output="2",
        incorrect_outputs=["11", "34", "3.14"],
    )
    dataset: Dataset = [paris, addition]
    algorithm = RepeReadingControl(multi_answer_method="repeat_correct")
    training_data, labels = algorithm._build_repe_training_data_and_labels(
        dataset, InputOutputFormatter()
    )
    # The prompts come back as one flat list while the labels stay grouped
    # per example, mirroring the original repe code.
    assert len(training_data) == 10
    # One (negative, positive) pair per incorrect output for the first
    # example, then (positive, negative) pairs for the second.
    expected_answers = [
        "Germany",
        "France",
        "Italy",
        "France",
        "2",
        "11",
        "2",
        "34",
        "2",
        "3.14",
    ]
    for prompt, answer in zip(training_data, expected_answers):
        assert answer in prompt
    # label order flips between consecutive examples
    assert labels == [[0, 1, 0, 1], [1, 0, 1, 0, 1, 0]]


def test_RepeReadingControl_get_directions(
model: GPTNeoXForCausalLM, tokenizer: Tokenizer
) -> None:
Expand All @@ -47,12 +123,14 @@ def test_RepeReadingControl_get_directions(
incorrect_outputs=["Germany", "Italy"],
),
]
algorithm = RepeReadingControl()
algorithm = RepeReadingControl(multi_answer_method="repeat_correct")
directions = algorithm._get_directions(pipeline, dataset)
assert list(directions.activations.keys()) == [-1, -2, -3, -4, -5]
assert list(directions.signs.keys()) == [-1, -2, -3, -4, -5]
for act in directions.activations.values():
assert act.shape == (1, 512)
for sign in directions.signs.values():
assert sign in [-1, 1]


def test_RepeReadingControl_run(
Expand All @@ -61,19 +139,27 @@ def test_RepeReadingControl_run(
tokenizer.pad_token_id = model.config.eos_token_id
pipeline = Pipeline(model, tokenizer)

example = Example(
test_example = Example(
instruction="",
input="Paris is in",
output="France",
incorrect_outputs=["Germany", "Italy"],
)
dataset: Dataset = [example]
dataset: Dataset = [
test_example,
Example(
instruction="",
input="1 + 1 =",
output="2",
incorrect_outputs=["11", "34", "3.14"],
),
]

original_outputs = pipeline.generate(example)
original_outputs = pipeline.generate(test_example)

algorithm = RepeReadingControl()
algorithm.run(pipeline, dataset)
new_outputs = pipeline.generate(example)
new_outputs = pipeline.generate(test_example)

# TODO: find a better assertion that ensures this is actually doing what it should
assert original_outputs != new_outputs

0 comments on commit 8eb3b75

Please sign in to comment.