
Refactoring formatting and benchmarking to support CAA (#87)
* refactoring formatting and benchmarking to support CAA

* adding a basic test for get_normalized_correct_probs()

* fixing tests

* increasing sft loss threshold to make test less flaky
chanind authored Jan 31, 2024
1 parent c1fb281 commit c98f067
Showing 18 changed files with 621 additions and 274 deletions.
2 changes: 1 addition & 1 deletion repepo/algorithms/icl.py
@@ -13,5 +13,5 @@ def __init__(self, max_icl_examples: int = 5):
     def run(self, pipeline: Pipeline, dataset: Dataset) -> dict[str, Any]:
         """Uses an in-context learning prefix to prompts"""
         icl_msgs = dataset[: self.max_icl_examples]
-        pipeline.conversation_wrapper.conversation_history = icl_msgs
+        pipeline.conversation_history = icl_msgs
         return {}
33 changes: 14 additions & 19 deletions repepo/algorithms/repe.py
@@ -21,8 +21,7 @@

 MultiAnswerMethod = Literal["first_incorrect", "random_incorrect", "repeat_correct"]

-REPE_READING_TEMPLATE = "Consider the amount of truthfulness in the following answer.\n\n{question} {answer}\n\nThe amount of truthfulness in the answer is "
-CAA_READING_TEMPLATE = "<s> {question}{answer}"
+REPE_READING_TEMPLATE = "Consider the amount of truthfulness in the following answer.\n\n{prompt}\n\nThe amount of truthfulness in the answer is "
 # Sample prompt from CAA:

 # <s> [INST] <<SYS>>
@@ -35,17 +34,13 @@
 # (A) Agree
 # (B) Disagree [/INST] (A)

-DEFAULT_READING_TEMPLATE = CAA_READING_TEMPLATE
+DEFAULT_READING_TEMPLATE = "{prompt}"


 def _validate_reading_template(reading_template: str):
-    if "{answer}" not in reading_template:
+    if "{prompt}" not in reading_template:
         raise ValueError(
-            "reading_template must contain {answer} to be used with RepEngReadingControl"
-        )
-    if "{question}" not in reading_template:
-        raise ValueError(
-            "reading_template must contain {question} to be used with RepEngReadingControl"
+            "reading_template must contain {prompt} to be used with RepEngReadingControl"
         )


@@ -101,6 +96,7 @@ class RepeReadingControl(Algorithm):
     skip_first_n_generation_tokens: int
     read_token_index: int
     seed: int
+    show_progress: bool

     def __init__(
         self,
@@ -120,6 +116,7 @@ def __init__(
         # make sure to set this to -2 when working with CAA's data format
         # Reference: https://github.com/nrimsky/SycophancySteering/blob/25f93a1f1aad51f94288f52d01f6a10d10f42bf1/generate_vectors.py#L102C13-L102C67
         read_token_index: int = -1,
+        show_progress: bool = True,
     ):
         self.multi_answer_method = multi_answer_method
         self.layer_type = layer_type
@@ -132,6 +129,7 @@ def __init__(
         self.read_token_index = read_token_index
         self.layer_config = layer_config
         self.direction_multiplier = direction_multiplier
+        self.show_progress = show_progress

         self.skip_reading = skip_reading
         self.override_vector = override_vector
@@ -175,6 +173,7 @@ def _get_steering_vector(
             layer_config=self.layer_config,
             move_to_cpu=True,
             read_token_index=self.read_token_index,
+            show_progress=self.show_progress,
         )

     @override
@@ -233,23 +232,19 @@ def _convert_example_to_training_samples(
         else:
             raise ValueError(f"Unknown multi_answer_method {self.multi_answer_method}")
         assert len(incorrect_examples) == len(correct_examples)
-        paired_completions = [
+        paired_prompts = [
             (
-                pipeline.build_completion(pos),
-                pipeline.build_completion(neg),
+                pipeline.build_full_prompt(pos),
+                pipeline.build_full_prompt(neg),
             )
             for pos, neg in zip(correct_examples, incorrect_examples)
         ]
         return [
             SteeringVectorTrainingSample(
-                positive_prompt=self.reading_template.format(
-                    question=pos.prompt, answer=pos.response
-                ),
-                negative_prompt=self.reading_template.format(
-                    question=neg.prompt, answer=neg.response
-                ),
+                positive_prompt=self.reading_template.format(prompt=pos),
+                negative_prompt=self.reading_template.format(prompt=neg),
             )
-            for pos, neg in paired_completions
+            for pos, neg in paired_prompts
         ]


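As a quick illustration (not part of the diff), here is a minimal sketch of how the reworked {prompt}-based reading template gets filled. The prompt string below is hypothetical, and build_full_prompt is assumed to return a fully formatted prompt string, as the change above suggests:

# Sketch only: the default reading template is now a pass-through over the full prompt.
reading_template = "{prompt}"  # DEFAULT_READING_TEMPLATE after this change
# Hypothetical stand-in for what pipeline.build_full_prompt(example) might return:
full_prompt = "[INST] Is the following statement true? ...\n(A) Agree\n(B) Disagree [/INST] (A)"
positive_prompt = reading_template.format(prompt=full_prompt)
assert positive_prompt == full_prompt
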
4 changes: 1 addition & 3 deletions repepo/baselines/sft/train.py
@@ -65,9 +65,7 @@ def make_supervised_data_module(
     """Make dataset and collator for supervised fine-tuning."""
     examples: list[Example] = get_dataset(data_args.dataset_name)
     fmt = InstructionFormatter()
-    completions: list[Completion] = [
-        fmt.format_conversation([ex])[0] for ex in examples
-    ]
+    completions: list[Completion] = [fmt.format_conversation(ex) for ex in examples]
     train_dataset = sft.SupervisedDataset(completions, tokenizer=tokenizer)
     data_collator = sft.DataCollatorForSupervisedDataset(tokenizer=tokenizer)
     return train_dataset, data_collator
104 changes: 61 additions & 43 deletions repepo/core/benchmark.py
@@ -1,13 +1,17 @@
 # pyright: strict, reportMissingTypeStubs=false

-from dataclasses import dataclass, replace
+from dataclasses import dataclass
 from typing import Optional, Sequence

 from transformers.generation import GenerationConfig

 from repepo.algorithms.base import Algorithm
-from repepo.core.conversation_wrapper import ConversationWrapper
-from repepo.core.evaluate import EvalPrediction, EvalResult, Evaluator
+from repepo.core.evaluate import (
+    EvalHook,
+    EvalResult,
+    Evaluator,
+    evaluate,
+)
 from repepo.core.format import Formatter, InputOutputFormatter
 from repepo.core.pipeline import Pipeline

@@ -22,59 +26,73 @@ class Benchmark:
     evaluators: list[Evaluator]


-def train_and_evaluate_benchmark(
+def train_benchmark(
     model: Model,
     tokenizer: Tokenizer,
     algorithms: Sequence[Algorithm],
     benchmark: Benchmark,
     formatter: Optional[Formatter] = None,
-    conversation_wrapper: Optional[ConversationWrapper] = None,
     generation_config: Optional[GenerationConfig] = None,
-) -> EvalResult:
+) -> Pipeline:
     pipeline = Pipeline(
         model,
         tokenizer,
         formatter=formatter or InputOutputFormatter(),
-        conversation_wrapper=conversation_wrapper or ConversationWrapper(),
     )

     # train pipeline
     for algorithm in algorithms:
         # Re-initialize pipeline, which gets destructively modified
         # TODO: do something with outputs?
         _ = algorithm.run(pipeline, benchmark.train_dataset)
+    return pipeline


+def evaluate_benchmark(
+    pipeline: Pipeline,
+    benchmark: Benchmark,
+    generation_config: Optional[GenerationConfig] = None,
+    # these eval_hooks allow us to do custom stuff to the pipeline only during evaluation,
+    # e.g. mess with the formatter to use CAA's special answer format
+    eval_hooks: list[EvalHook] = [],
+    show_progress: bool = True,
+    tqdm_desc: str = "Evaluating",
+) -> EvalResult:
-    # evaluate
-    predictions: list[EvalPrediction] = []
-    requires_generation = any([e.requires_generation for e in benchmark.evaluators])
-    requires_probs = any([e.requires_probs for e in benchmark.evaluators])
-    # TODO: support batching
-    for example in benchmark.test_dataset:
-        generated_output = None
-        correct_output_probs = None
-        incorrect_outputs_probs = None
-        if requires_generation:
-            generated_output = pipeline.generate(
-                example, generation_config=generation_config
-            )
-        if requires_probs:
-            correct_output_probs = pipeline.calculate_output_logprobs(example)
-            if example.incorrect_outputs is not None:
-                incorrect_outputs_probs = [
-                    pipeline.calculate_output_logprobs(
-                        replace(example, output=incorrect_output)
-                    )
-                    for incorrect_output in example.incorrect_outputs
-                ]
-        predictions.append(
-            EvalPrediction(
-                example=example,
-                generated_output=generated_output,
-                correct_output_probs=correct_output_probs,
-                incorrect_outputs_probs=incorrect_outputs_probs,
-            )
-        )
-    metrics: dict[str, float] = {}
-    for evaluator in benchmark.evaluators:
-        metrics.update(evaluator(predictions))
-    return EvalResult(predictions, metrics)
+    return evaluate(
+        pipeline,
+        dataset=benchmark.test_dataset,
+        evaluators=benchmark.evaluators,
+        generation_config=generation_config,
+        eval_hooks=eval_hooks,
+        show_progress=show_progress,
+        tqdm_desc=tqdm_desc,
+    )


+def train_and_evaluate_benchmark(
+    model: Model,
+    tokenizer: Tokenizer,
+    algorithms: Sequence[Algorithm],
+    benchmark: Benchmark,
+    formatter: Optional[Formatter] = None,
+    generation_config: Optional[GenerationConfig] = None,
+    # these eval_hooks allow us to do custom stuff to the pipeline only during evaluation,
+    # e.g. mess with the formatter to use CAA's special answer format
+    eval_hooks: list[EvalHook] = [],
+    show_progress: bool = True,
+) -> EvalResult:
+    # train
+    pipeline = train_benchmark(
+        model,
+        tokenizer,
+        algorithms,
+        benchmark,
+        formatter=formatter,
+    )

+    # evaluate
+    return evaluate_benchmark(
+        pipeline,
+        benchmark,
+        generation_config=generation_config,
+        eval_hooks=eval_hooks,
+        show_progress=show_progress,
+    )
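
For orientation, a hedged usage sketch of the refactored helpers above (not from this commit). The model, tokenizer, algorithm list, and benchmark are assumed to be built elsewhere, run_benchmark is a hypothetical wrapper, and only the call shapes come from the diff:

# Sketch only: train once, then evaluate separately, optionally customizing the
# pipeline during evaluation via eval_hooks.
from repepo.core.benchmark import Benchmark, evaluate_benchmark, train_benchmark


def run_benchmark(model, tokenizer, algorithms, benchmark: Benchmark) -> dict[str, float]:
    # train_benchmark applies each algorithm to a fresh Pipeline and returns it
    pipeline = train_benchmark(model, tokenizer, algorithms, benchmark)
    # evaluate_benchmark delegates to repepo.core.evaluate.evaluate; an eval hook
    # could, e.g., switch the formatter to CAA's special answer format
    result = evaluate_benchmark(pipeline, benchmark, eval_hooks=[], show_progress=True)
    # assumes EvalResult exposes .metrics, as in the pre-refactor code
    return result.metrics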
32 changes: 0 additions & 32 deletions repepo/core/conversation_wrapper.py

This file was deleted.
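
The deletion above pairs with the icl.py change at the top of this commit: conversation history now lives directly on the Pipeline rather than on a nested ConversationWrapper. A two-line migration sketch, assuming an existing pipeline and the icl_msgs list from that diff:

# Before this commit:
#   pipeline.conversation_wrapper.conversation_history = icl_msgs
# After this commit:
pipeline.conversation_history = icl_msgs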

(Diffs for the remaining changed files were not loaded in this view.)
