refactoring prompting/formatting (#77)
* refactoring prompting/formatting

* fixing conflict in tests
chanind authored Jan 24, 2024
1 parent f86d0d2 commit dca53ac
Showing 17 changed files with 275 additions and 188 deletions.
10 changes: 2 additions & 8 deletions repepo/algorithms/icl.py
@@ -1,4 +1,3 @@
-from repepo.core.prompt import FewShotPrompter
 from .base import Algorithm
 from repepo.core import Pipeline, Dataset
 from typing import Any
@@ -13,11 +12,6 @@ def __init__(self, max_icl_examples: int = 5):
 
     def run(self, pipeline: Pipeline, dataset: Dataset) -> dict[str, Any]:
         """Uses an in-context learning prefix to prompts"""
-
-        icl_completions = pipeline.formatter.apply_list(
-            dataset[: self.max_icl_examples]
-        )
-        new_prompter = FewShotPrompter(icl_completions)
-        pipeline.prompter = new_prompter
-
+        icl_msgs = dataset[: self.max_icl_examples]
+        pipeline.conversation_wrapper.conversation_history = icl_msgs
         return {}
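
Net effect: in-context examples are no longer baked in by a FewShotPrompter; they become shared conversation history that the pipeline's formatter renders ahead of every prompt. A minimal sketch, assuming a `pipeline` and `dataset` that are not shown in this diff:

# Sketch only: `pipeline` and `dataset` are assumed to exist already.
pipeline.conversation_wrapper.conversation_history = dataset[:5]
prompt = pipeline.build_generation_prompt(dataset[5])
# `prompt` now starts with the five history examples, rendered by the
# pipeline's formatter and joined with the wrapper's msg_separator.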
14 changes: 8 additions & 6 deletions repepo/algorithms/repe.py
@@ -16,7 +16,6 @@
 )
 
 from repepo.core import Pipeline
-from repepo.core.format import Formatter
 from repepo.core.types import Dataset, Example, Model, Tokenizer
 from repepo.algorithms.base import Algorithm
 
@@ -129,12 +128,12 @@ def __init__(
         )
 
     def _build_steering_vector_training_data(
-        self, dataset: Dataset, formatter: Formatter
+        self, dataset: Dataset, pipeline: Pipeline
     ) -> list[SteeringVectorTrainingSample]:
         paired_prompts: list[SteeringVectorTrainingSample] = []
         for example in dataset:
             example_prompts = self._convert_example_to_training_samples(
-                example, formatter
+                example, pipeline
             )
             paired_prompts.extend(example_prompts)
         return paired_prompts
@@ -151,7 +150,7 @@ def _get_steering_vector(
         self, pipeline: Pipeline, dataset: Dataset
     ) -> SteeringVector:
         repe_training_data = self._build_steering_vector_training_data(
-            dataset, pipeline.formatter
+            dataset, pipeline
         )
         return train_steering_vector(
             pipeline.model,
@@ -201,7 +200,7 @@ def run(self, pipeline: Pipeline, dataset: Dataset) -> Pipeline:
         return pipeline
 
     def _convert_example_to_training_samples(
-        self, example: Example, formatter: Formatter
+        self, example: Example, pipeline: Pipeline
    ) -> list[SteeringVectorTrainingSample]:
        """Converts an example to the format expected by steering-vectors"""
        if example.incorrect_outputs is None:
@@ -221,7 +220,10 @@ def _convert_example_to_training_samples(
             raise ValueError(f"Unknown multi_answer_method {self.multi_answer_method}")
         assert len(incorrect_examples) == len(correct_examples)
         paired_completions = [
-            (formatter.apply(pos), formatter.apply(neg))
+            (
+                pipeline.build_completion(pos),
+                pipeline.build_completion(neg),
+            )
             for pos, neg in zip(correct_examples, incorrect_examples)
         ]
         return [
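
Threading the whole Pipeline through (instead of just its Formatter) means steering-vector training pairs get the same conversation wrapping used at inference. A hypothetical sketch of the pairing, with made-up values and an existing `pipeline`; the Example constructor shown is an assumption:

# Hypothetical Example values; field names follow repepo.core.types.
pos = pipeline.build_completion(
    Example(instruction="", input="Is the sky blue?", output="Yes")
)
neg = pipeline.build_completion(
    Example(instruction="", input="Is the sky blue?", output="No")
)
# pos and neg share identical prompt text and differ only in response,
# which is what the SteeringVectorTrainingSample pairing above relies on.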
4 changes: 3 additions & 1 deletion repepo/algorithms/sft.py
@@ -54,7 +54,9 @@ def run(
         tokenizer.pad_token = tokenizer.eos_token
 
         # Make dataset
-        completions: List[Completion] = pipeline.formatter.apply_list(dataset)
+        completions: List[Completion] = [
+            pipeline.build_completion(ex) for ex in dataset
+        ]
         train_dataset = sft.SupervisedDataset(completions, tokenizer=tokenizer)
         data_collator = sft.DataCollatorForSupervisedDataset(tokenizer=tokenizer)
 
5 changes: 4 additions & 1 deletion repepo/baselines/sft/train.py
@@ -64,7 +64,10 @@ def make_supervised_data_module(
 ) -> tuple[sft.SupervisedDataset, sft.DataCollatorForSupervisedDataset]:
     """Make dataset and collator for supervised fine-tuning."""
     examples: list[Example] = get_dataset(data_args.dataset_name)
-    completions: list[Completion] = InstructionFormatter().apply_list(examples)
+    fmt = InstructionFormatter()
+    completions: list[Completion] = [
+        fmt.format_conversation([ex])[0] for ex in examples
+    ]
     train_dataset = sft.SupervisedDataset(completions, tokenizer=tokenizer)
     data_collator = sft.DataCollatorForSupervisedDataset(tokenizer=tokenizer)
     return train_dataset, data_collator
9 changes: 8 additions & 1 deletion repepo/core/benchmark.py
@@ -6,6 +6,7 @@
 from transformers.generation import GenerationConfig
 
 from repepo.algorithms.base import Algorithm
+from repepo.core.conversation_wrapper import ConversationWrapper
 from repepo.core.evaluate import EvalPrediction, EvalResult, Evaluator
 from repepo.core.format import Formatter, InputOutputFormatter
 from repepo.core.pipeline import Pipeline
@@ -27,9 +28,15 @@ def train_and_evaluate_benchmark(
     algorithms: Sequence[Algorithm],
     benchmark: Benchmark,
     formatter: Optional[Formatter] = None,
+    conversation_wrapper: Optional[ConversationWrapper] = None,
     generation_config: Optional[GenerationConfig] = None,
 ) -> EvalResult:
-    pipeline = Pipeline(model, tokenizer, formatter=formatter or InputOutputFormatter())
+    pipeline = Pipeline(
+        model,
+        tokenizer,
+        formatter=formatter or InputOutputFormatter(),
+        conversation_wrapper=conversation_wrapper or ConversationWrapper(),
+    )
 
     # train pipeline
     for algorithm in algorithms:
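
A hypothetical call sketch for the extended signature; only the keyword arguments come from this diff, while `model`, `tokenizer`, and `benchmark` are assumed to be constructed elsewhere (the positional parameters are truncated out of the hunk above):

from repepo.core.benchmark import train_and_evaluate_benchmark
from repepo.core.conversation_wrapper import ConversationWrapper
from repepo.core.format import LlamaChatFormatter

result = train_and_evaluate_benchmark(
    model,
    tokenizer,
    algorithms=[],
    benchmark=benchmark,
    formatter=LlamaChatFormatter(),
    conversation_wrapper=ConversationWrapper(),
)
# Omitting both keywords falls back to InputOutputFormatter() and a
# default ConversationWrapper(), matching the old single-line behavior.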
32 changes: 32 additions & 0 deletions repepo/core/conversation_wrapper.py
@@ -0,0 +1,32 @@
+from dataclasses import dataclass, field
+
+from .types import Completion, Example
+from .format import Formatter
+
+
+def completion_to_str(completion: Completion) -> str:
+    return completion.prompt.rstrip() + " " + completion.response.lstrip()
+
+
+@dataclass
+class ConversationWrapper:
+    """Wraps a conversation, summarizing a list of messages into a single completion."""
+
+    template: str = "{conversation}"
+    conversation_history: list[Example] = field(default_factory=list)
+    msg_separator: str = "\n"
+
+    def wrap(self, formatter: Formatter, message: Example) -> Completion:
+        completions = formatter.format_conversation(
+            [*self.conversation_history, message]
+        )
+        prefix_completions = completions[:-1]
+        final_completion = completions[-1]
+        convo_prefix = self.msg_separator.join(
+            [completion_to_str(c) for c in prefix_completions]
+        )
+        convo_str = final_completion.prompt
+        if len(convo_prefix) > 0:
+            convo_str = convo_prefix + self.msg_separator + final_completion.prompt
+        prompt = self.template.format(conversation=convo_str)
+        return Completion(prompt=prompt, response=final_completion.response)
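
A usage sketch for the new wrapper, assuming this commit's package is importable and that Example can be built from (instruction, input, output) alone:

from repepo.core.conversation_wrapper import ConversationWrapper
from repepo.core.format import InputOutputFormatter
from repepo.core.types import Example

history = [Example(instruction="", input="2 + 2?", output="4")]
wrapper = ConversationWrapper(conversation_history=history)
final = wrapper.wrap(
    InputOutputFormatter(),
    Example(instruction="", input="3 + 3?", output="6"),
)
# final.prompt is roughly "Input: 2 + 2? \nOutput: 4\nInput: 3 + 3? \nOutput:"
# (history rendered as prompt+response pairs, then the new prompt);
# final.response stays "6", so the model only completes the last turn.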
49 changes: 40 additions & 9 deletions repepo/core/format.py
@@ -1,22 +1,42 @@
 import abc
+from dataclasses import dataclass
 from typing import List
 from typing_extensions import override
 
 from repepo.core.types import Completion
 from repepo.core.types import Example
 
 
+@dataclass
+class FormatContext:
+    """
+    Context provided to the formatter in the format() method.
+    """
+
+    index: int
+    examples: List[Example]
+
+    @property
+    def num_examples(self) -> int:
+        return len(self.examples)
+
+
 class Formatter(abc.ABC):
     """Describes how to format examples as completions"""
 
     @abc.abstractmethod
-    def apply(self, example: Example, **kwargs) -> Completion:
+    def format(self, example: Example, ctx: FormatContext) -> Completion:
+        """
+        Format an example as a completion. Subclasses should override this method.
+        This method should not be called directly externally, instead use format_conversation().
+        """
         pass
 
-    def apply_list(self, examples: List[Example]) -> List[Completion]:
+    def format_conversation(self, conversation: List[Example]) -> List[Completion]:
         completions: list[Completion] = []
-        for example in examples:
-            completion = self.apply(example)
+        for i, example in enumerate(conversation):
+            ctx = FormatContext(index=i, examples=conversation)
+            completion = self.format(example, ctx)
             completions.append(completion)
         return completions
 
@@ -27,8 +47,7 @@ class InputOutputFormatter(Formatter):
     PROMPT_TEMPLATE = "Input: {instruction} {input} \n" "Output: "
 
     @override
-    def apply(self, example: Example, **kwargs):
-        del kwargs
+    def format(self, example: Example, ctx: FormatContext):
         return Completion(
             prompt=self.PROMPT_TEMPLATE.format(
                 instruction=example.instruction, input=example.input
@@ -40,16 +59,29 @@ def apply(self, example: Example, **kwargs):
 class LlamaChatFormatter(Formatter):
     """
     Add [INST] and [/INST] tags to the instruction and input.
+    Also adds a system message before the first prompt.
     Based on: https://github.com/nrimsky/SycophancySteering/blob/main/utils/tokenize_llama.py#L30
     """
 
+    system_prompt: str | None
+
     B_INST = "[INST]"
     E_INST = "[/INST]"
+    B_SYS = "<<SYS>>\n"
+    E_SYS = "\n<</SYS>>\n"
+
+    def __init__(
+        self,
+        system_prompt: str | None = "You are a helpful, honest and concise assistant.",
+    ) -> None:
+        self.system_prompt = system_prompt
 
     @override
-    def apply(self, example: Example):
+    def format(self, example: Example, ctx: FormatContext):
         dialog_content_parts = []
+        if ctx.index == 0 and self.system_prompt is not None:
+            dialog_content_parts.append(f"{self.B_SYS}{self.system_prompt}{self.E_SYS}")
         if example.instruction:
             dialog_content_parts.append(example.instruction.strip())
         dialog_content_parts.append(example.input.strip())
@@ -74,8 +106,7 @@ class InstructionFormatter(Formatter):
     )
 
     @override
-    def apply(self, example: Example, **kwargs):
-        del kwargs
+    def format(self, example: Example, ctx: FormatContext):
         if bool(example.input):
             prompt = self.PROMPT_INPUT.format(
                 instruction=example.instruction, input=example.input
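
The FormatContext plumbing is what lets a formatter behave differently by position in a conversation; LlamaChatFormatter uses it to emit the system block only on the first message. A short sketch under the same importability assumption (the final prompt assembly is truncated from the hunk above, so the <<SYS>> placement is inferred):

from repepo.core.format import LlamaChatFormatter
from repepo.core.types import Example

fmt = LlamaChatFormatter()
first, second = fmt.format_conversation(
    [
        Example(instruction="", input="Hello", output="Hi!"),
        Example(instruction="", input="Bye", output="See you!"),
    ]
)
# Only `first` should carry the <<SYS>> system prompt, since format()
# checks ctx.index == 0 before prepending it.
assert "<<SYS>>" in first.prompt
assert "<<SYS>>" not in second.prompt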
25 changes: 17 additions & 8 deletions repepo/core/pipeline.py
@@ -5,8 +5,8 @@
 from transformers.generation import GenerationConfig
 import torch
 
-from .types import Example, Model, Tokenizer
-from .prompt import Prompter, IdentityPrompter
+from .types import Completion, Example, Model, Tokenizer
+from .conversation_wrapper import ConversationWrapper
 from .format import Formatter, InputOutputFormatter
 
@@ -48,15 +48,18 @@ class Pipeline:
 
     model: Model
     tokenizer: Tokenizer
-    prompter: Prompter = field(default_factory=IdentityPrompter)
     formatter: Formatter = field(default_factory=InputOutputFormatter)
+    conversation_wrapper: ConversationWrapper = field(
+        default_factory=ConversationWrapper
+    )
     hooks: list[PipelineHook] = field(default_factory=list)
 
     def build_generation_prompt(self, example: Example) -> str:
         """Build a prompt for generation"""
-        completion = self.formatter.apply(example)
-        completion = self.prompter.apply(completion)
-        return completion.prompt
+        return self.build_completion(example).prompt.rstrip()
+
+    def build_completion(self, example: Example) -> Completion:
+        return self.conversation_wrapper.wrap(self.formatter, example)
 
     def generate(
         self,
@@ -90,8 +93,9 @@ def generate(
 
     def calculate_output_logprobs(self, example: Example) -> TextProbs:
         """Calculate the logprobs for each token in the prompt + output"""
-        base_prompt = self.build_generation_prompt(example)
-        full_prompt = base_prompt + example.output
+        completion = self.build_completion(example)
+        base_prompt = completion.prompt
+        full_prompt = _build_full_prompt(completion)
         inputs: Any = self.tokenizer(full_prompt, return_tensors="pt")
         inputs = inputs.to(self.model.device)
         context = PipelineContext(
@@ -122,3 +126,8 @@ def calculate_output_logprobs(self, example: Example) -> TextProbs:
             )
             return TextProbs(text=full_prompt, token_probs=text_probs)
         raise RuntimeError("Should never get here")
+
+
+def _build_full_prompt(completion: Completion) -> str:
+    """Build a prompt for generation"""
+    return completion.prompt.rstrip() + " " + completion.response.lstrip()
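
To summarize the two paths through the refactored Pipeline, a sketch assuming `model` and `tokenizer` are already loaded:

from repepo.core.pipeline import Pipeline
from repepo.core.types import Example

pipeline = Pipeline(model, tokenizer)
ex = Example(instruction="", input="2 + 2?", output="4")

gen_prompt = pipeline.build_generation_prompt(ex)  # prompt only, rstrip'd
completion = pipeline.build_completion(ex)         # prompt + response pair
# For scoring, _build_full_prompt() joins the two with a single space:
full = completion.prompt.rstrip() + " " + completion.response.lstrip()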
75 changes: 0 additions & 75 deletions repepo/core/prompt.py

This file was deleted.
