diff --git a/repepo/algorithms/repe.py b/repepo/algorithms/repe.py
index a8e4bf6b..f6476aba 100644
--- a/repepo/algorithms/repe.py
+++ b/repepo/algorithms/repe.py
@@ -1,12 +1,12 @@
 from dataclasses import replace
 from typing import Any, Literal, NamedTuple, Optional, cast
 from typing_extensions import override
-import numpy as np
+import random
 import torch
 
 from repepo.core import Pipeline
 from repepo.core.format import Formatter
-from repepo.core.types import Dataset, Model
+from repepo.core.types import Dataset, Example, Model
 from repepo.repe.rep_reading_pipeline import RepReadingPipeline
 from repepo.algorithms.base import Algorithm
 from repepo.utils.layer_matching import (
@@ -18,6 +18,7 @@
 from repepo.utils.model_patcher import LayerType, ModelPatcher
 
 DirectionMethod = Literal["pca", "cluster_mean", "random"]
+MultiAnswerMethod = Literal["first_incorrect", "random_incorrect", "repeat_correct"]
 
 
 DEFAULT_READING_TEMPLATE = "Consider the amount of truthfulness in the following answer.\n\n{question} {answer}\n\nThe amount of truthfulness in the answer is "
@@ -42,13 +43,13 @@ class RepeTrainingData(NamedTuple):
 
 class RepeDirections(NamedTuple):
     activations: dict[int, torch.Tensor]
-    # why is this a dict of np arrays instead of a dict of torch tensors? mystery
-    signs: dict[int, np.ndarray]
+    signs: dict[int, int]
 
 
 class RepeReadingControl(Algorithm):
     direction_method: DirectionMethod
     layer_type: LayerType
+    multi_answer_method: MultiAnswerMethod
     reading_template: str
     layers: list[int] | None
     n_difference: int
@@ -56,22 +57,27 @@ class RepeReadingControl(Algorithm):
     max_length: int
     layer_config: ModelLayerConfig | None
     direction_finder_kwargs: dict[str, Any]
+    seed: int
 
     def __init__(
        self,
        reading_template: str = DEFAULT_READING_TEMPLATE,
        direction_method: DirectionMethod = "pca",
        layer_type: LayerType = "decoder_block",
+        multi_answer_method: MultiAnswerMethod = "first_incorrect",
        n_difference: int = 1,  # TODO: what does this do?
        batch_size: int = 8,
        max_length: int = 2048,
+        seed: int = 0,
        layers: Optional[list[int]] = None,
        layer_config: Optional[ModelLayerConfig] = None,
        # TODO: remove this when refactoring repe reading pipeline
        direction_finder_kwargs: Optional[dict[str, Any]] = None,
    ):
        self.direction_method = direction_method
+        self.multi_answer_method = multi_answer_method
        self.layer_type = layer_type
+        self.seed = seed
        _validate_reading_template(reading_template)
        self.reading_template = reading_template
        self.layers = layers
@@ -86,24 +92,16 @@ def _build_repe_training_data_and_labels(
     ) -> RepeTrainingData:
         prompts: list[str] = []
         grouped_labels: list[list[int]] = []
-        for example in dataset:
-            if example.incorrect_outputs is None:
-                raise ValueError(
-                    "RepEngReadingControl requires incorrect_outputs to be set"
-                )
-            incorrect_examples = [
-                replace(example, output=output) for output in example.incorrect_outputs
-            ]
-            label_group = [1, *([0] * len(incorrect_examples))]
-            grouped_labels.append(label_group)
-            group_examples = [example, *incorrect_examples]
-            for group_example in group_examples:
-                completion = formatter.apply(group_example)
-                prompts.append(
-                    self.reading_template.format(
-                        question=completion.prompt, answer=completion.response
-                    )
-                )
+        for i, example in enumerate(dataset):
+            example_prompts, example_labels = self._convert_example_to_repe_format(
+                example,
+                formatter,
+                # repe doesn't reflect differences across the origin, so we need to ensure
+                # an even balance of reversed and non-reversed examples
+                reverse_order=i % 2 == 0,
+            )
+            prompts.extend(example_prompts)
+            grouped_labels.append(example_labels)
         return RepeTrainingData(prompts=prompts, labels=grouped_labels)
 
     def _get_layer_matcher_for_model(self, model: Model) -> LayerMatcher:
@@ -154,7 +152,7 @@ def _get_directions(self, pipeline: Pipeline, dataset: Dataset) -> RepeDirection
                 key: torch.FloatTensor(val)
                 for key, val in rep_reader.directions.items()
             },
-            signs=rep_reader.direction_signs,
+            signs={key: val.item() for key, val in rep_reader.direction_signs.items()},
         )
 
     @override
@@ -167,3 +165,43 @@ def run(self, pipeline: Pipeline, dataset: Dataset) -> Pipeline:
             layer_type=self.layer_type,
         )
         return pipeline
+
+    def _convert_example_to_repe_format(
+        self, example: Example, formatter: Formatter, reverse_order: bool
+    ) -> tuple[list[str], list[int]]:
+        """Converts an example to the format expected by repe"""
+        if example.incorrect_outputs is None:
+            raise ValueError(
+                "RepEngReadingControl requires incorrect_outputs to be set"
+            )
+        incorrect_examples = [
+            replace(example, output=output) for output in example.incorrect_outputs
+        ]
+        correct_examples = [example]
+        if self.multi_answer_method == "repeat_correct":
+            correct_examples = [example] * len(example.incorrect_outputs)
+        elif self.multi_answer_method == "first_incorrect":
+            incorrect_examples = [incorrect_examples[0]]
+        elif self.multi_answer_method == "random_incorrect":
+            rand_gen = random.Random(f"{self.seed}-{example.input}")
+            incorrect_examples = [rand_gen.choice(incorrect_examples)]
+        else:
+            raise ValueError(f"Unknown multi_answer_method {self.multi_answer_method}")
+        assert len(incorrect_examples) == len(correct_examples)
+        # repe (undocumentedly) expects interleaved positive and negative examples
+        if reverse_order:
+            labels = [0, 1] * len(correct_examples)
+            paired_examples = zip(incorrect_examples, correct_examples)
+        else:
+            labels = [1, 0] * len(correct_examples)
+            paired_examples = zip(correct_examples, incorrect_examples)
+        # interleaved pos, neg, pos, neg, ...
+        examples = [ex for pair in paired_examples for ex in pair]
+        completions = [formatter.apply(ex) for ex in examples]
+        prompts = [
+            self.reading_template.format(
+                question=completion.prompt, answer=completion.response
+            )
+            for completion in completions
+        ]
+        return prompts, labels
diff --git a/tests/algorithms/__snapshots__/test_repe.ambr b/tests/algorithms/__snapshots__/test_repe.ambr
index af294e08..dc97cf67 100644
--- a/tests/algorithms/__snapshots__/test_repe.ambr
+++ b/tests/algorithms/__snapshots__/test_repe.ambr
@@ -1,14 +1,6 @@
 # serializer version: 1
-# name: test_RepeReadingControl_build_repe_training_data_and_labels
+# name: test_RepeReadingControl_build_repe_training_data_and_labels_picks_one_neg_by_default
   list([
-    '''
-    Consider the amount of truthfulness in the following answer.
-
-    Input: Paris is in
-    Output: France
-
-    The amount of truthfulness in the answer is 
-    ''',
     '''
     Consider the amount of truthfulness in the following answer.
 
@@ -21,7 +13,7 @@
     Consider the amount of truthfulness in the following answer.
 
     Input: Paris is in
-    Output: Italy
+    Output: France
 
     The amount of truthfulness in the answer is 
     ''',
@@ -39,22 +31,6 @@
     Input: 1 + 1 =
     Output: 11
 
-    The amount of truthfulness in the answer is 
-    ''',
-    '''
-    Consider the amount of truthfulness in the following answer.
-
-    Input: 1 + 1 =
-    Output: 1234
-
-    The amount of truthfulness in the answer is 
-    ''',
-    '''
-    Consider the amount of truthfulness in the following answer.
-
-    Input: 1 + 1 =
-    Output: 3.14
-
     The amount of truthfulness in the answer is 
     ''',
   ])
diff --git a/tests/algorithms/test_repe.py b/tests/algorithms/test_repe.py
index e192bab1..14807f3b 100644
--- a/tests/algorithms/test_repe.py
+++ b/tests/algorithms/test_repe.py
@@ -6,7 +6,7 @@
 from transformers import GPTNeoXForCausalLM
 
 
-def test_RepeReadingControl_build_repe_training_data_and_labels(
+def test_RepeReadingControl_build_repe_training_data_and_labels_picks_one_neg_by_default(
     snapshot: SnapshotAssertion,
 ) -> None:
     dataset: Dataset = [
@@ -20,7 +20,7 @@ def test_RepeReadingControl_build_repe_training_data_and_labels(
             instruction="",
             input="1 + 1 =",
             output="2",
-            incorrect_outputs=["11", "1234", "3.14"],
+            incorrect_outputs=["11", "34", "3.14"],
         ),
     ]
     formatter = InputOutputFormatter()
@@ -29,11 +29,87 @@ def test_RepeReadingControl_build_repe_training_data_and_labels(
         dataset, formatter
     )
     # for some reason the training data isn't grouped, but labels are. This is how it is in the original code.
-    assert len(training_data) == 7
-    assert labels == [[1, 0, 0], [1, 0, 0, 0]]
+    assert len(training_data) == 4
+    # should pick the first incorrect output only by default
+    assert "Germany" in training_data[0]
+    assert "France" in training_data[1]
+    assert "2" in training_data[2]
+    assert "11" in training_data[3]
+    # should alternate between flipped and non-flipped labels
+    assert labels == [[0, 1], [1, 0]]
     assert training_data == snapshot
 
 
+def test_RepeReadingControl_build_repe_training_data_and_labels_with_random_incorrect() -> (
+    None
+):
+    dataset: Dataset = [
+        Example(
+            instruction="",
+            input="Paris is in",
+            output="France",
+            incorrect_outputs=["Germany", "Italy"],
+        ),
+        Example(
+            instruction="",
+            input="1 + 1 =",
+            output="2",
+            incorrect_outputs=["11", "34", "3.14"],
+        ),
+    ]
+    formatter = InputOutputFormatter()
+    algorithm = RepeReadingControl(multi_answer_method="random_incorrect")
+    training_data, labels = algorithm._build_repe_training_data_and_labels(
+        dataset, formatter
+    )
+    # for some reason the training data isn't grouped, but labels are. This is how it is in the original code.
+    assert len(training_data) == 4
+    # should pick a random incorrect output
+    assert "France" in training_data[1]
+    assert "2" in training_data[2]
+    # should alternate between flipped and non-flipped labels
+    assert labels == [[0, 1], [1, 0]]
+
+
+def test_RepeReadingControl_build_repe_training_data_and_labels_with_repeat_correct() -> (
+    None
+):
+    dataset: Dataset = [
+        Example(
+            instruction="",
+            input="Paris is in",
+            output="France",
+            incorrect_outputs=["Germany", "Italy"],
+        ),
+        Example(
+            instruction="",
+            input="1 + 1 =",
+            output="2",
+            incorrect_outputs=["11", "34", "3.14"],
+        ),
+    ]
+    formatter = InputOutputFormatter()
+    algorithm = RepeReadingControl(multi_answer_method="repeat_correct")
+    training_data, labels = algorithm._build_repe_training_data_and_labels(
+        dataset, formatter
+    )
+    # for some reason the training data isn't grouped, but labels are. This is how it is in the original code.
+    assert len(training_data) == 10
+    # the positive example should be repeated once for each incorrect output
+    assert "Germany" in training_data[0]
+    assert "France" in training_data[1]
+    assert "Italy" in training_data[2]
+    assert "France" in training_data[3]
+    assert "2" in training_data[4]
+    assert "11" in training_data[5]
+    assert "2" in training_data[6]
+    assert "34" in training_data[7]
+    assert "2" in training_data[8]
+    assert "3.14" in training_data[9]
+    # should alternate between flipped and non-flipped labels
+    assert labels == [[0, 1, 0, 1], [1, 0, 1, 0, 1, 0]]
+
+
 def test_RepeReadingControl_get_directions(
     model: GPTNeoXForCausalLM, tokenizer: Tokenizer
 ) -> None:
@@ -47,12 +123,14 @@ def test_RepeReadingControl_get_directions(
             incorrect_outputs=["Germany", "Italy"],
         ),
     ]
-    algorithm = RepeReadingControl()
+    algorithm = RepeReadingControl(multi_answer_method="repeat_correct")
    directions = algorithm._get_directions(pipeline, dataset)
    assert list(directions.activations.keys()) == [-1, -2, -3, -4, -5]
    assert list(directions.signs.keys()) == [-1, -2, -3, -4, -5]
    for act in directions.activations.values():
        assert act.shape == (1, 512)
+    for sign in directions.signs.values():
+        assert sign in [-1, 1]
 
 
 def test_RepeReadingControl_run(
@@ -61,19 +139,27 @@ def test_RepeReadingControl_run(
     model: GPTNeoXForCausalLM, tokenizer: Tokenizer
 ) -> None:
     tokenizer.pad_token_id = model.config.eos_token_id
     pipeline = Pipeline(model, tokenizer)
 
-    example = Example(
+    test_example = Example(
        instruction="",
        input="Paris is in",
        output="France",
        incorrect_outputs=["Germany", "Italy"],
    )
-    dataset: Dataset = [example]
+    dataset: Dataset = [
+        test_example,
+        Example(
+            instruction="",
+            input="1 + 1 =",
+            output="2",
+            incorrect_outputs=["11", "34", "3.14"],
+        ),
+    ]
 
-    original_outputs = pipeline.generate(example)
+    original_outputs = pipeline.generate(test_example)
    algorithm = RepeReadingControl()
    algorithm.run(pipeline, dataset)
-    new_outputs = pipeline.generate(example)
+    new_outputs = pipeline.generate(test_example)
    # TODO: find a better assertion that ensures this is actually doing what it should
    assert original_outputs != new_outputs
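
Not part of the patch: a minimal usage sketch of the new multi_answer_method option, based only on the code above. RepeReadingControl, Example, and _build_repe_training_data_and_labels come directly from the patched files; the import path for InputOutputFormatter is an assumption (the test file's import block is not shown in this diff).

    # Usage sketch (assumed imports; see note above).
    from repepo.algorithms.repe import RepeReadingControl
    from repepo.core.format import InputOutputFormatter  # assumed location
    from repepo.core.types import Example

    dataset = [
        Example(
            instruction="",
            input="Paris is in",
            output="France",
            incorrect_outputs=["Germany", "Italy"],
        ),
    ]

    # "first_incorrect" (default): pair the correct output with the first incorrect one.
    # "random_incorrect": pick one incorrect output per example, seeded from `seed` and the input.
    # "repeat_correct": repeat the correct output once per incorrect output.
    algorithm = RepeReadingControl(multi_answer_method="repeat_correct", seed=0)
    prompts, labels = algorithm._build_repe_training_data_and_labels(
        dataset, InputOutputFormatter()
    )
    # With one example (index 0, so reverse_order=True) and two incorrect outputs,
    # prompts holds 4 interleaved completions and labels == [[0, 1, 0, 1]].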