
Commit 64d089e

lewtun and qgallouedec authored
Reasoning reward (#4563)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
1 parent 3b7d0e4 commit 64d089e

File tree: 8 files changed, +195 −130 lines


README.md

Lines changed: 3 additions & 0 deletions
````diff
@@ -104,6 +104,9 @@ trainer = GRPOTrainer(
 trainer.train()
 ```
 
+> [!NOTE]
+> For reasoning models, use the `reasoning_accuracy_reward()` function for better results.
+
 ### `DPOTrainer`
 
 [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer) implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train [Llama 3](https://huggingface.co/papers/2407.21783) and many other models. Here is a basic example of how to use the `DPOTrainer`:
````
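For a sense of what the new note recommends, here is a minimal sketch of calling the reward directly. The call shapes and expected values are taken from the tests added in this commit; having the math-verify dependency installed is assumed.

```python
# Illustrative sketch, not part of the commit: calling the new reward directly.
from trl.rewards import reasoning_accuracy_reward

# Each completion is a list of messages; only the message content is scored.
completions = [[{"content": r"<think> 6 * 7 = 42 </think> The answer is \boxed{42}."}]]
solutions = [r"\boxed{42}"]

# Expected: [1.0] for a correct answer placed after the closing </think> tag.
print(reasoning_accuracy_reward(completions, solutions))
```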

docs/source/lora_without_regret.md

Lines changed: 2 additions & 94 deletions
````diff
@@ -135,94 +135,6 @@ The blog post performs GRPO on a range of models and datasets from the Hub, and

 For reinforcement learning, the blog uses a math reasoning task that we can reproduce as a Python function.

-<details>
-<summary>Reward function</summary>
-
-```python
-def strip_reasoning_accuracy_reward(
-    completions: list[list[dict[str, str]]], solution: list[str], **kwargs
-) -> list[float | None]:
-    """Reward function that strips reasoning tags and checks mathematical accuracy.
-
-    This function:
-    1. Extracts the content from completions
-    2. Removes <think></think> tags (for reasoning that shouldn't be evaluated)
-    3. Parses both the gold solution and the predicted answer
-    4. Uses math_verify to check if they are mathematically equivalent
-
-    Args:
-        completions: List of model completions, each containing a list of messages
-        solution: List of ground truth solutions
-        **kwargs: Additional arguments (ignored but required for trainer compatibility)
-
-    Returns:
-        List of rewards where:
-        - 1.0 if the answer is correct
-        - 0.0 if the answer is incorrect
-        - None if the solution is not parseable (skips this example)
-    """
-    contents = [completion[0]["content"] for completion in completions]
-    rewards = []
-
-    for content, sol in zip(contents, solution):
-        # Strip reasoning tags from completion
-        while "<think>" in content and "</think>" in content:
-            start = content.find("<think>")
-            end = content.find("</think>", start)
-            if start != -1 and end != -1:
-                content = content[:start] + content[end + len("</think>") :]
-            else:
-                break
-
-        # Parse gold solution
-        gold_parsed = parse(
-            f"${sol}$",
-            extraction_config=[
-                LatexExtractionConfig(
-                    boxed_match_priority=0, try_extract_without_anchor=True
-                )
-            ],
-        )
-
-        if len(gold_parsed) != 0:
-            # We require the answer to be provided in correct latex (no malformed operators)
-            answer_parsed = parse(
-                content,
-                extraction_config=[
-                    LatexExtractionConfig(
-                        boxed_match_priority=0,
-                        normalization_config=NormalizationConfig(
-                            basic_latex=True,
-                            units=True,
-                            malformed_operators=False,
-                            nits=False,
-                            boxed=True,
-                        ),
-                        try_extract_without_anchor=False,
-                    )
-                ],
-                extraction_mode="first_match",
-            )
-
-            # Compute binary rewards if verifiable, `None` otherwise to skip this example
-            try:
-                reward = float(verify(gold_parsed, answer_parsed))
-            except Exception as e:
-                print(
-                    f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}"
-                )
-                reward = None
-        else:
-            # If the gold solution is not parseable, we assign `None` to skip this example
-            reward = None
-
-        rewards.append(reward)
-
-    return rewards
-```
-
-</details>
-
 <hfoptions id="grpo">
 <hfoption id="python">

@@ -233,14 +145,10 @@ We can implement these recommendations with the TRL Python API like so:
 from datasets import load_dataset
 from peft import LoraConfig
 from trl import GRPOConfig, GRPOTrainer
+from trl.rewards import reasoning_accuracy_reward

 dataset = load_dataset("HuggingFaceH4/OpenR1-Math-220k-default-verified", split="train")

-def strip_reasoning_accuracy_reward(completions, **kwargs):
-    """Reward function that strips reasoning and accuracy scores from the model outputs."""
-
-    ...
-
 peft_config = LoraConfig(
     r=1,
     lora_alpha=32,
@@ -259,7 +167,7 @@ training_args = GRPOConfig(

 trainer = GRPOTrainer(
     model="Qwen/Qwen3-0.6B",
-    reward_funcs=strip_reasoning_accuracy_reward,
+    reward_funcs=reasoning_accuracy_reward,
     args=training_args,
     train_dataset=dataset,
     peft_config=peft_config,
````
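The deleted helper stripped `<think>…</think>` spans before verifying. Judging by the new tests, the built-in replacement is stricter: the answer must appear after the closing reasoning delimiter, and a completion that never closes its reasoning scores 0.0. A hypothetical sketch of that splitting step (an illustration inferred from the tests, not TRL's actual source):

```python
# Hypothetical sketch of the delimiter handling implied by the new tests;
# not the real implementation of `reasoning_accuracy_reward`.
def split_after_reasoning(content: str, delimiters: tuple[str, ...] = ("</think>",)) -> str | None:
    """Return only the text after the last closing reasoning delimiter, or None if none closes."""
    for delim in delimiters:
        if delim in content:
            return content.rsplit(delim, 1)[1]
    return None  # incomplete reasoning -> the reward function scores 0.0


assert split_after_reasoning(r"<think> work </think> \boxed{42}") == r" \boxed{42}"
assert split_after_reasoning("<think> no closing tag") is None
```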

docs/source/rewards.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ This module contains some useful reward functions, primarily intended for use wi
66

77
[[autodoc]] rewards.accuracy_reward
88

9+
## reasoning_accuracy_reward
10+
11+
[[autodoc]] rewards.reasoning_accuracy_reward
12+
913
## think_format_reward
1014

1115
[[autodoc]] rewards.think_format_reward

tests/test_rewards.py

Lines changed: 67 additions & 18 deletions
```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward
+from trl.rewards import accuracy_reward, get_soft_overlong_punishment, reasoning_accuracy_reward, think_format_reward
 
 from .testing_utils import TrlTestCase, require_math_latex
 
@@ -117,27 +117,76 @@ def test_accuracy_reward_unparseable_gold(self):
         """Test accuracy_reward with an unparseable gold solution."""
         completion = [
             [{"content": "Answer is forty two."}],
-            [{"content": "Some other content."}],
-            [{"content": r"Answer is \boxed{42}."}],
-            [{"content": r"Answer is \boxed{\mathbf{42}}."}],  # Make response bold
-            [{"content": r"Answer is \boxed{\textbf{42}}."}],  # Different latex command for bold
-            [{"content": r"Answer is \boxed{42}."}],
-            [{"content": r"Answer is \boxed{42.3456}."}],
+            [{"content": r"Some other content. \boxed{43}."}],
         ]
         solution = [
             "Answer is forty two.",
             "Answer is forty three.",
-            "Answer is 42.0",  # Decimal point
-            "Answer is 42 43 okay?",  # Extra space
-            "Answer is 42",
-            r"Answer is \n\boxed{42}",  # Newline in gold solution
-            "Answer is 42.34560",  # Extra trailing zero
         ]
         rewards = accuracy_reward(completion, solution)
-        assert rewards[0] == 1.0  # Should revert to exact text match
+        assert rewards[0] is None
+        assert rewards[1] is None
+
+
+class TestReasoningAccuracyReward:
+    @require_math_latex
+    def test_correct_answer_yields_unit_reward(self):
+        completions = [
+            [{"content": r"<think> Reasoning content </think> \boxed{\frac{63}{400}}"}],
+            [{"content": r"Reasoning content </think> \boxed{\frac{63}{400}}"}],
+        ]
+        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
+        rewards = reasoning_accuracy_reward(completions, solutions)
+        assert rewards[0] == 1.0
+        assert rewards[1] == 1.0
+
+    @require_math_latex
+    def test_correct_answer_with_custom_tags_yields_unit_reward(self):
+        completions = [
+            [{"content": r"<REASONING_START> Reasoning content </REASONING_END> \boxed{\frac{63}{400}}"}],
+        ]
+        solutions = [
+            r"\frac{63}{400}",
+        ]
+        rewards = reasoning_accuracy_reward(completions, solutions, reasoning_delimiters=["</REASONING_END>"])
+        assert rewards[0] == 1.0
+
+    @require_math_latex
+    def test_incorrect_answer_yields_zero_reward(self):
+        completion = [[{"content": r"<think> Reasoning content </think> \boxed{\frac{64}{400}}"}]]
+        solution = [r"\frac{63}{400}"]
+        rewards = reasoning_accuracy_reward(completion, solution)
+        assert rewards[0] == 0.0
+
+    @require_math_latex
+    def test_correct_answer_in_reasoning_yields_zero_reward(self):
+        completions = [
+            [{"content": r"<think> My answer is \boxed{42} </think> Some other text."}],
+            [{"content": r"<think> The answer is \boxed{42} </think> Here's a wrong answer: \boxed{43}."}],
+        ]
+        solutions = [r"\boxed{42}", r"\boxed{42}"]
+        rewards = reasoning_accuracy_reward(completions, solutions)
+        assert rewards[0] == 0.0
+        assert rewards[1] == 0.0
+
+    @require_math_latex
+    def test_incomplete_reasoning_yields_zero_reward(self):
+        completions = [
+            [{"content": r"<think> Incomplete reasoning without closing tag"}],
+            [{"content": r"Correct answer \frac{63}{400} but completely missing reasoning content"}],
+        ]
+        solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
+        rewards = reasoning_accuracy_reward(completions, solutions)
+        assert rewards[0] == 0.0
         assert rewards[1] == 0.0
-        assert rewards[2] == 1.0
-        assert rewards[3] == 1.0
-        assert rewards[4] == 1.0
-        assert rewards[5] == 1.0
-        assert rewards[6] == 1.0  # Should ignore trailing zeros
+
+    @require_math_latex
+    def test_unparseable_gold_solution_yields_none_reward(self):
+        completions = [
+            [{"content": r"<think> Reasoning content </think> \boxed{42}"}],
+        ]
+        solutions = [
+            "forty two",
+        ]
+        rewards = reasoning_accuracy_reward(completions, solutions)
+        assert rewards[0] is None
```
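The custom-tag test above also documents the `reasoning_delimiters` keyword; a usage sketch derived directly from that test (the tag names come from the test, not a recommendation):

```python
# Usage sketch derived from test_correct_answer_with_custom_tags_yields_unit_reward.
from trl.rewards import reasoning_accuracy_reward

completions = [[{"content": r"<REASONING_START> Reasoning content </REASONING_END> \boxed{\frac{63}{400}}"}]]
solutions = [r"\frac{63}{400}"]

# Expected: [1.0]; the answer is only scored after the custom closing delimiter.
print(reasoning_accuracy_reward(completions, solutions, reasoning_delimiters=["</REASONING_END>"]))
```

The new suite can be run with `pytest tests/test_rewards.py -k TestReasoningAccuracyReward`; the `require_math_latex` marker skips these tests when the math-verify dependency is not installed.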

trl/rewards/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -19,14 +19,14 @@
 
 
 _import_structure = {
-    "accuracy_rewards": ["accuracy_reward"],
+    "accuracy_rewards": ["accuracy_reward", "reasoning_accuracy_reward"],
     "format_rewards": ["think_format_reward"],
     "other_rewards": ["get_soft_overlong_punishment"],
 }
 
 
 if TYPE_CHECKING:
-    from .accuracy_rewards import accuracy_reward
+    from .accuracy_rewards import accuracy_reward, reasoning_accuracy_reward
     from .format_rewards import think_format_reward
     from .other_rewards import get_soft_overlong_punishment
 
```
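Both edits are needed because `trl.rewards` resolves its attributes lazily at runtime from `_import_structure`, while the `TYPE_CHECKING` imports exist only for static analyzers. A simplified sketch of the general pattern (PEP 562 module `__getattr__`; TRL's actual mechanism may differ):

```python
# Simplified illustration of a lazy-import pattern; not TRL's actual code.
import importlib

_import_structure = {
    "accuracy_rewards": ["accuracy_reward", "reasoning_accuracy_reward"],
}


def __getattr__(name: str):
    # Import the owning submodule only on first attribute access.
    for submodule, exported in _import_structure.items():
        if name in exported:
            module = importlib.import_module(f".{submodule}", __name__)
            return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```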
