Skip to content

Commit a59f2cf

Browse files
qgallouedecbehroozazarkhalilikashif
authored
Move WinRateCallback to experimental (#4558)
Co-authored-by: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com> Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
1 parent cf431db commit a59f2cf

File tree

8 files changed

+533
-396
lines changed

8 files changed

+533
-396
lines changed

docs/source/_toctree.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@
8787
- sections:
8888
- local: experimental_overview
8989
title: Experimental Overview
90+
- local: openenv
91+
title: OpenEnv Integration
9092
- local: bema_for_reference_model # Sorted alphabetically
9193
title: BEMA for Reference Model
9294
- local: bco_trainer
@@ -119,8 +121,8 @@
119121
title: PPO
120122
- local: prm_trainer
121123
title: PRM
124+
- local: winrate_callback
125+
title: WinRateCallback
122126
- local: xpo_trainer
123127
title: XPO
124-
- local: openenv
125-
title: OpenEnv Integration
126128
title: Experimental

docs/source/callbacks.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,6 @@
88

99
[[autodoc]] RichProgressCallback
1010

11-
## WinRateCallback
12-
13-
[[autodoc]] WinRateCallback
14-
1511
## LogCompletionsCallback
1612

1713
[[autodoc]] LogCompletionsCallback

docs/source/winrate_callback.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# WinRateCallback
2+
3+
[[autodoc]] experimental.winrate_callback.WinRateCallback
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from datasets import load_dataset
17+
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
18+
from transformers.utils import is_peft_available
19+
20+
from trl.experimental.judges import BasePairwiseJudge
21+
from trl.experimental.winrate_callback import WinRateCallback
22+
23+
from ..testing_utils import TrlTestCase, require_peft
24+
25+
26+
if is_peft_available():
27+
from peft import LoraConfig
28+
29+
30+
class HalfPairwiseJudge(BasePairwiseJudge):
31+
"""Naive pairwise judge that always returns [1, 0] for two prompts"""
32+
33+
def judge(self, prompts, completions, shuffle_order=True, return_scores=False):
34+
# just check that the batch size is 2
35+
assert len(prompts) == 2
36+
if return_scores:
37+
return [0.3, 0.9]
38+
return [1, 0]
39+
40+
41+
class TrainerWithRefModel(Trainer):
42+
# This is a dummy class to test the callback. Compared to the Trainer class, it only has an additional
43+
# ref_model attribute
44+
def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processing_class):
45+
super().__init__(
46+
model=model,
47+
args=args,
48+
train_dataset=train_dataset,
49+
eval_dataset=eval_dataset,
50+
processing_class=processing_class,
51+
)
52+
self.ref_model = ref_model
53+
54+
55+
class TestWinRateCallback(TrlTestCase):
56+
def setup_method(self):
57+
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
58+
self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
59+
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
60+
self.tokenizer.pad_token = self.tokenizer.eos_token
61+
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
62+
dataset["train"] = dataset["train"].select(range(8))
63+
self.expected_winrates = [
64+
{"eval_win_rate": 0.5, "epoch": 0.0, "step": 0},
65+
{"eval_win_rate": 0.5, "epoch": 0.5, "step": 2},
66+
{"eval_win_rate": 0.5, "epoch": 1.0, "step": 4},
67+
{"eval_win_rate": 0.5, "epoch": 1.5, "step": 6},
68+
{"eval_win_rate": 0.5, "epoch": 2.0, "step": 8},
69+
{"eval_win_rate": 0.5, "epoch": 2.5, "step": 10},
70+
{"eval_win_rate": 0.5, "epoch": 3.0, "step": 12},
71+
]
72+
73+
def tokenize_function(examples):
74+
out = self.tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True)
75+
out["labels"] = out["input_ids"].copy()
76+
return out
77+
78+
self.dataset = dataset.map(tokenize_function, batched=True)
79+
80+
self.generation_config = GenerationConfig(max_length=32)
81+
self.judge = HalfPairwiseJudge()
82+
83+
def test_basic(self):
84+
training_args = TrainingArguments(
85+
output_dir=self.tmp_dir,
86+
eval_strategy="steps",
87+
eval_steps=2, # evaluate every 2 steps
88+
per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch
89+
per_device_eval_batch_size=2,
90+
report_to="none",
91+
)
92+
trainer = TrainerWithRefModel(
93+
model=self.model,
94+
ref_model=self.ref_model,
95+
args=training_args,
96+
train_dataset=self.dataset["train"],
97+
eval_dataset=self.dataset["test"],
98+
processing_class=self.tokenizer,
99+
)
100+
win_rate_callback = WinRateCallback(
101+
judge=self.judge, trainer=trainer, generation_config=self.generation_config
102+
)
103+
trainer.add_callback(win_rate_callback)
104+
trainer.train()
105+
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
106+
for history_row, expected_row in zip(winrate_history, self.expected_winrates, strict=True):
107+
assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
108+
109+
def test_without_ref_model(self):
110+
# Same as before, but without the ref_model attribute. It should use the model attribute instead
111+
training_args = TrainingArguments(
112+
output_dir=self.tmp_dir,
113+
eval_strategy="steps",
114+
eval_steps=2, # evaluate every 2 steps
115+
per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch
116+
per_device_eval_batch_size=2,
117+
report_to="none",
118+
)
119+
trainer = Trainer(
120+
model=self.model,
121+
args=training_args,
122+
train_dataset=self.dataset["train"],
123+
eval_dataset=self.dataset["test"],
124+
processing_class=self.tokenizer,
125+
)
126+
win_rate_callback = WinRateCallback(
127+
judge=self.judge, trainer=trainer, generation_config=self.generation_config
128+
)
129+
trainer.add_callback(win_rate_callback)
130+
trainer.train()
131+
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
132+
for history_row, expected_row in zip(winrate_history, self.expected_winrates, strict=True):
133+
assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
134+
135+
def test_soft_judge(self):
136+
"""Test that the soft judge functionality works correctly"""
137+
training_args = TrainingArguments(
138+
output_dir=self.tmp_dir,
139+
eval_strategy="steps",
140+
eval_steps=2, # evaluate every 2 steps
141+
per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch
142+
per_device_eval_batch_size=2,
143+
report_to="none",
144+
)
145+
trainer = TrainerWithRefModel(
146+
model=self.model,
147+
ref_model=self.ref_model,
148+
args=training_args,
149+
train_dataset=self.dataset["train"],
150+
eval_dataset=self.dataset["test"],
151+
processing_class=self.tokenizer,
152+
)
153+
win_rate_callback = WinRateCallback(
154+
judge=self.judge, trainer=trainer, generation_config=self.generation_config, use_soft_judge=True
155+
)
156+
trainer.add_callback(win_rate_callback)
157+
trainer.train()
158+
159+
# Expected values based on judge returning [0.3, 0.9] for each pair
160+
expected_soft_winrates = [
161+
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.0, "step": 0},
162+
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.5, "step": 2},
163+
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.0, "step": 4},
164+
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.5, "step": 6},
165+
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.0, "step": 8},
166+
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.5, "step": 10},
167+
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 3.0, "step": 12},
168+
]
169+
170+
winrate_history = [
171+
{k: h[k] for k in ["eval_avg_win_prob", "eval_win_rate", "epoch", "step"]}
172+
for h in trainer.state.log_history
173+
if "eval_avg_win_prob" in h
174+
]
175+
for history_row, expected_row in zip(winrate_history, expected_soft_winrates, strict=True):
176+
assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
177+
178+
@require_peft
179+
def test_lora(self):
180+
peft_config = LoraConfig(
181+
r=16,
182+
lora_alpha=32,
183+
lora_dropout=0.05,
184+
bias="none",
185+
task_type="CAUSAL_LM",
186+
)
187+
self.model.add_adapter(peft_config)
188+
training_args = TrainingArguments(
189+
output_dir=self.tmp_dir,
190+
eval_strategy="steps",
191+
eval_steps=2, # evaluate every 2 steps
192+
per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch
193+
per_device_eval_batch_size=2,
194+
report_to="none",
195+
)
196+
trainer = Trainer(
197+
model=self.model,
198+
args=training_args,
199+
train_dataset=self.dataset["train"],
200+
eval_dataset=self.dataset["test"],
201+
processing_class=self.tokenizer,
202+
)
203+
win_rate_callback = WinRateCallback(
204+
judge=self.judge, trainer=trainer, generation_config=self.generation_config
205+
)
206+
trainer.add_callback(win_rate_callback)
207+
trainer.train()
208+
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
209+
for history_row, expected_row in zip(winrate_history, self.expected_winrates, strict=True):
210+
assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)

0 commit comments

Comments
 (0)