Remove prompts arg from WinrateCallback #2010

Merged 4 commits on Sep 3, 2024
104 changes: 104 additions & 0 deletions tests/test_callback.py
@@ -0,0 +1,104 @@
import tempfile

from datasets import Dataset, DatasetDict
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments

from trl import BasePairwiseJudge, WinRateCallback


class ThreeQuatersPairwiseJudge(BasePairwiseJudge):
"""Naive pairwise judge that always returns [1, 0, 1, 1, 0, 1, 1, 1]"""

def judge(self, prompts, completions, shuffle_order=True):
        # sanity check: the judge receives all 8 evaluation prompts at once
assert len(prompts) == 8
return [1, 0, 1, 1, 0, 1, 1, 1]


class TrainerWithRefModel(Trainer):
# This is a dummy class to test the callback. Compared to the Trainer class, it only has an additional
# ref_model attribute
def __init__(self, model, ref_model, args, trainer_dataset, eval_dataset, tokenizer):
super().__init__(
model=model, args=args, train_dataset=trainer_dataset, eval_dataset=eval_dataset, tokenizer=tokenizer
)
self.ref_model = ref_model


def test_trainer_callback():
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
tokenizer.pad_token = tokenizer.eos_token
dataset = DatasetDict(
{
"train": Dataset.from_dict(
{
"prompt": [
"Hello world!",
"This is a test.",
"We are creating a dataset.",
"It has eight lines.",
"Each line is a sentence.",
"The sentences are simple.",
"This is just for testing.",
"Goodbye!",
]
}
),
"test": Dataset.from_dict(
{
"prompt": [
"The sun sets in the west.",
"Mountains are majestic.",
"Rivers flow endlessly.",
"Forests are full of life.",
"Birds sing in the morning.",
"Waves crash on the shore.",
"The moon glows at night.",
"Stars twinkle in the sky.",
]
}
),
}
)

def tokenize_function(examples):
out = tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True)
out["labels"] = out["input_ids"].copy()
return out

dataset = dataset.map(tokenize_function, batched=True)

with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
output_dir=tmp_dir,
eval_strategy="steps",
eval_steps=2, # evaluate every 2 steps
per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch
per_device_eval_batch_size=2,
report_to="none",
)
trainer = TrainerWithRefModel(
model=model,
ref_model=ref_model,
args=args,
trainer_dataset=dataset["train"],
eval_dataset=dataset["test"],
tokenizer=tokenizer,
)
generation_config = GenerationConfig(max_length=32)
win_rate_callback = WinRateCallback(
judge=ThreeQuatersPairwiseJudge(), trainer=trainer, generation_config=generation_config
)
trainer.add_callback(win_rate_callback)
trainer.train()
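        # The dummy judge picks the trained model's completion on 6 of the 8 eval prompts,
        # so every evaluation is expected to log a win rate of 6 / 8 = 0.75.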
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
assert winrate_history == [
{"eval_win_rate": 0.75, "epoch": 0.5, "step": 2},
{"eval_win_rate": 0.75, "epoch": 1.0, "step": 4},
{"eval_win_rate": 0.75, "epoch": 1.5, "step": 6},
{"eval_win_rate": 0.75, "epoch": 2.0, "step": 8},
{"eval_win_rate": 0.75, "epoch": 2.5, "step": 10},
{"eval_win_rate": 0.75, "epoch": 3.0, "step": 12},
]
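The test above uses a judge that ignores the completions entirely. For readers unfamiliar with the pairwise-judge interface the callback now relies on, here is a minimal runnable sketch; the `LongerIsBetterJudge` name and its length heuristic are illustrative, and the pair-of-candidates layout plus the 0/1 return convention are assumptions inferred from how the callback and the test use the judge:

```python
from trl import BasePairwiseJudge


class LongerIsBetterJudge(BasePairwiseJudge):
    """Toy judge: for each prompt, prefer the longer of the two candidate completions."""

    def judge(self, prompts, completions, shuffle_order=True):
        # `completions` holds one [candidate_0, candidate_1] pair per prompt; return the
        # index (0 or 1) of the preferred candidate for each pair.
        return [0 if len(pair[0]) >= len(pair[1]) else 1 for pair in completions]


judge = LongerIsBetterJudge()
print(judge.judge(["Say hi"], [["Hi", "Hello there, nice to meet you!"]]))  # -> [1]
```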
24 changes: 16 additions & 8 deletions trl/trainer/callbacks.py
@@ -34,7 +34,7 @@
from transformers.trainer_utils import has_length

from ..models.utils import unwrap_model_for_generation
from .judges import BaseRankJudge
from .judges import BasePairwiseJudge
from .utils import truncate_right


@@ -158,6 +158,10 @@ class WinRateCallback(TrainerCallback):
"""
    A [`~transformers.TrainerCallback`] that computes the win rate of a model relative to a reference model.

It uses prompts from the evaluation dataset to generate completions for both the model and the reference model.
At every evaluation step, it compares the completions generated by the model and the reference model using a judge
and computes the win rate. This win rate is then logged to the trainer under the key `"eval_win_rate"`.

Usage:
```python
trainer = DPOTrainer(...)
@@ -166,12 +170,14 @@ class WinRateCallback(TrainerCallback):
```

Args:
prompts (`List[str]`):
The prompts to generate completions for.
judge (`BaseRankJudge`):
judge (`BasePairwiseJudge`):
The judge to use for comparing completions.
trainer (`Trainer`):
The trainer.
The trainer. The trainer must comply with the following requirements:

- its evaluation dataset must have a column `"prompt"` that contains the prompts to generate completions for.
- it must have an attribute `ref_model` that contains the reference model.

generation_config (`GenerationConfig`, *optional*):
The generation config to use for generating completions.
batch_size (`int`, *optional*):
@@ -180,13 +186,11 @@ class WinRateCallback(TrainerCallback):

def __init__(
self,
prompts: List[str],
judge: BaseRankJudge,
judge: BasePairwiseJudge,
trainer: Trainer,
generation_config: Optional[GenerationConfig] = None,
batch_size: int = 4,
):
self.prompts = prompts
self.generation_config = generation_config
self.judge = judge
self.ref_completions = []
@@ -217,13 +221,17 @@ def generate_completions_for_model(self, model, tokenizer, prompts):
return completions

def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # When training begins, we generate completions for the reference model once; they are
        # reused at every evaluation step.
tokenizer = kwargs["tokenizer"]
tokenizer.padding_side = "left"
accelerator = self.trainer.accelerator
with accelerator.split_between_processes(self.eval_dataset["prompt"], apply_padding=True) as prompts:
self.ref_completions = self.generate_completions_for_model(self.trainer.ref_model, tokenizer, prompts)

def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
# At every evaluation step, we generate completions for the model and compare them with the reference
# completions that have been generated at the beginning of training. We then compute the win rate and log it to
# the trainer.
model = kwargs["model"]
tokenizer = kwargs["tokenizer"]
accelerator = self.trainer.accelerator
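To summarize the API change, here is a hedged before/after sketch of wiring up the callback; as in the docstring's usage example, the trainer construction is elided, and `judge` and `trainer` are assumed to exist already (any `BasePairwiseJudge`, and a trainer whose evaluation dataset has a `"prompt"` column and which exposes a `ref_model` attribute):

```python
from transformers import GenerationConfig
from trl import WinRateCallback

# Before this PR, the prompts were passed to the callback explicitly:
# win_rate_callback = WinRateCallback(prompts=prompts, judge=judge, trainer=trainer)

# After this PR, the prompts come from the trainer's evaluation dataset, and the judge
# must be a BasePairwiseJudge instead of a BaseRankJudge:
win_rate_callback = WinRateCallback(
    judge=judge,
    trainer=trainer,
    generation_config=GenerationConfig(max_new_tokens=32),  # optional
)
trainer.add_callback(win_rate_callback)
```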