Skip to content

Commit

Permalink
add cosine reward
Browse files — browse the repository at this point in the history
  • Loading branch information
kashif committed Feb 6, 2025
1 parent f8cbb98 commit 7913bc1
Showing 1 changed file with 78 additions and 2 deletions.
80 changes: 78 additions & 2 deletions src/open_r1/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import logging
import math
import os
import re
import sys
Expand Down Expand Up @@ -46,8 +47,8 @@ class GRPOScriptArguments(ScriptArguments):
"""

reward_funcs: list[str] = field(
default_factory=lambda: ["accuracy", "format"],
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
default_factory=lambda: ["accuracy", "format", "cosine"],
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format', 'cosine'"},
)


Expand Down Expand Up @@ -97,9 +98,84 @@ def format_reward(completions, **kwargs):
return [1.0 if match else 0.0 for match in matches]


def cosine_scaled_reward(completions, solution, **kwargs):
    """Reward function that scales based on completion length using a cosine schedule.

    Shorter correct solutions are rewarded more than longer ones.
    Longer incorrect solutions are penalized less than shorter ones.

    Args:
        completions: List of model completions (each a list of chat messages;
            only the first message's ``content`` is scored).
        solution: List of ground truth solutions (LaTeX strings).
        **kwargs: Additional arguments including:
            min_value_wrong (float): Minimum reward for wrong answers (default: -1.0)
            max_value_wrong (float): Maximum reward for wrong answers (default: -0.5)
            min_value_correct (float): Minimum reward for correct answers (default: 0.5)
            max_value_correct (float): Maximum reward for correct answers (default: 1.0)
            max_len (int): Maximum length for scaling (default: 1000)

    Returns:
        list[float]: One reward per completion.
    """
    contents = [completion[0]["content"] for completion in completions]
    rewards = []

    # Get constants from kwargs with defaults
    min_value_wrong = kwargs.get("min_value_wrong", -1.0)
    max_value_wrong = kwargs.get("max_value_wrong", -0.5)
    min_value_correct = kwargs.get("min_value_correct", 0.5)
    max_value_correct = kwargs.get("max_value_correct", 1.0)
    max_len = kwargs.get("max_len", 1000)

    for content, sol in zip(contents, solution):
        # First check if the answer is correct using existing parse/verify logic
        gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
        if len(gold_parsed) == 0:
            # Neutral (maximum) reward so unparseable gold answers neither
            # reward nor penalize the policy.
            rewards.append(1.0)  # Skip unparseable examples
            print("Failed to parse gold solution: ", sol)
            continue

        answer_parsed = parse(
            content,
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        equations=True,
                        boxed=True,
                        units=True,
                    ),
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )

        is_correct = verify(answer_parsed, gold_parsed)
        gen_len = len(content)

        # Apply cosine scaling based on length.
        # Clamp progress to [0, 1]: cos(progress * pi) is only monotone on that
        # interval, so without the clamp a completion longer than max_len would
        # oscillate back toward the maximum reward (or, for wrong answers, the
        # maximum penalty) instead of saturating at the max_len value.
        progress = min(gen_len / max_len, 1.0)
        cosine = math.cos(progress * math.pi)

        if is_correct:
            min_value = min_value_correct
            max_value = max_value_correct
        else:
            # Swap min/max for incorrect answers so that longer wrong answers
            # are penalized less than short confidently-wrong ones.
            min_value = max_value_wrong
            max_value = min_value_wrong

        # Cosine interpolation: progress=0 -> max_value, progress=1 -> min_value.
        reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
        rewards.append(float(reward))

    return rewards


# Maps the --reward_funcs CLI names to their implementations.
reward_funcs_registry = dict(
    accuracy=accuracy_reward,
    format=format_reward,
    cosine=cosine_scaled_reward,
)

SYSTEM_PROMPT = (
Expand Down

0 comments on commit 7913bc1

Please sign in to comment.