Commit
[Weighted reward functions] Adding functionality to weigh rewards. Tests.
zeenolife committed Feb 10, 2025
1 parent 517addd commit 031ec05
Showing 7 changed files with 133 additions and 2 deletions.
6 changes: 6 additions & 0 deletions recipes/DeepSeek-R1-Distill-Qwen-7B/grpo/config_demo.yaml
@@ -40,6 +40,12 @@ per_device_train_batch_size: 2
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
6 changes: 6 additions & 0 deletions recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml
@@ -42,6 +42,12 @@ per_device_train_batch_size: 2
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
6 changes: 6 additions & 0 deletions recipes/Qwen2.5-Math-7B/grpo/config_simple_rl.yaml
@@ -42,6 +42,12 @@ per_device_train_batch_size: 2
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
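All three recipes gain the same pair of keys: reward_funcs lists the reward functions to apply, and reward_weights is a parallel list of per-function multipliers (the demo configs keep every weight at 1.0, matching the previous unweighted behavior). As an illustration only, with hypothetical values not taken from this commit, a run that emphasizes answer correctness over output format could use:

reward_funcs:
- accuracy
- format
reward_weights:
- 2.0  # hypothetical: doubles the accuracy reward
- 0.5  # hypothetical: halves the format reward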
29 changes: 27 additions & 2 deletions src/open_r1/grpo.py
@@ -16,6 +16,7 @@
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import torch
@@ -27,6 +28,7 @@
from open_r1.configs import GRPOConfig
from open_r1.rewards import (
    accuracy_reward,
    create_weighted_reward,
    format_reward,
    get_cosine_scaled_reward,
    get_repetition_penalty_reward,
@@ -47,6 +49,8 @@ class GRPOScriptArguments(ScriptArguments):
    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'.
        reward_weights (`list[float]` or `None`, *optional*):
            List of weights for each reward function. If not provided, defaults to 1.0 for each reward function.
        cosine_min_value_wrong (`float`):
            Minimum reward for cosine scaling for wrong answers.
        cosine_max_value_wrong (`float`):
@@ -57,6 +61,7 @@ class GRPOScriptArguments(ScriptArguments):
            Maximum reward for cosine scaling for correct answers.
        cosine_max_len (`int`):
            Maximum length for cosine scaling.
    """

    reward_funcs: list[str] = field(
@@ -65,6 +70,12 @@ class GRPOScriptArguments(ScriptArguments):
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty'"
},
)
reward_weights: Optional[list[float]] = field(
default=None,
metadata={
"help": "List of weights for each reward function. If not provided, defaults to 1.0 for each function."
},
)
cosine_min_value_wrong: float = field(
default=0.0,
metadata={"help": "Minimum reward for wrong answers"},
@@ -86,6 +97,17 @@ class GRPOScriptArguments(ScriptArguments):
metadata={"help": "Maximum length for scaling"},
)

    def __post_init__(self):
        # If no weights were provided, default to 1.0 for each reward function
        if self.reward_weights is None:
            self.reward_weights = [1.0] * len(self.reward_funcs)
        # If weights were provided, validate the length
        elif len(self.reward_weights) != len(self.reward_funcs):
            raise ValueError(
                f"Number of reward weights ({len(self.reward_weights)}: {self.reward_weights}) must match "
                f"number of reward functions ({len(self.reward_funcs)}: {self.reward_funcs})"
            )

    repetition_n_grams: int = field(
        default=3,
        metadata={"help": "Number of n-grams for repetition penalty reward"},
@@ -142,7 +164,7 @@ def main(script_args, training_args, model_args):
    # Load the dataset
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

-    # Get reward functions
+    # Create weighted reward functions
    REWARD_FUNCS_REGISTRY = {
        "accuracy": accuracy_reward,
        "format": format_reward,
@@ -159,7 +181,10 @@
            max_penalty=script_args.repetition_max_penalty,
        ),
    }
-    reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
+    reward_funcs = [
+        create_weighted_reward(REWARD_FUNCS_REGISTRY[func], weight)
+        for func, weight in zip(script_args.reward_funcs, script_args.reward_weights)
+    ]

    # Format into conversation
    def make_conversation(example):
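To see how the pieces above fit together, here is a minimal, self-contained sketch of the wiring in main(): a registry maps names to reward functions, and each selected function is wrapped with its weight. The stub rewards below are hypothetical stand-ins for the real ones; only create_weighted_reward mirrors the helper added in this commit.

def accuracy_stub(completions, **kwargs):  # hypothetical stand-in for accuracy_reward
    return [1.0 for _ in completions]

def format_stub(completions, **kwargs):  # hypothetical stand-in for format_reward
    return [1.0 for _ in completions]

def create_weighted_reward(func, weight):
    # Same shape as the helper in src/open_r1/rewards.py: scale each reward by `weight`
    def weighted_reward(*args, **kwargs):
        return [r * weight for r in func(*args, **kwargs)]
    return weighted_reward

registry = {"accuracy": accuracy_stub, "format": format_stub}
names, weights = ["accuracy", "format"], [2.0, 0.5]
reward_funcs = [create_weighted_reward(registry[n], w) for n, w in zip(names, weights)]
print([f(completions=[["..."]]) for f in reward_funcs])  # [[2.0], [0.5]]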
18 changes: 18 additions & 0 deletions src/open_r1/rewards.py
@@ -197,3 +197,21 @@ def repetition_penalty_reward(completions, **kwargs) -> float:
        return rewards

    return repetition_penalty_reward


def create_weighted_reward(func, weight):
    """Create a weighted version of a reward function.

    Args:
        func: The reward function to weight
        weight: The weight to apply to the reward

    Returns:
        A new function that applies the weight to the reward
    """

    def weighted_reward(*args, **kwargs):
        rewards = func(*args, **kwargs)
        return [r * weight for r in rewards]

    return weighted_reward
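Because weighted_reward forwards *args and **kwargs unchanged, the wrapped function keeps the same call signature as the original and only scales its output. A quick usage sketch, with a hypothetical base reward:

def constant_reward(completions, **kwargs):  # hypothetical base reward
    return [1.0] * len(completions)

halved = create_weighted_reward(constant_reward, 0.5)
print(halved([["a"], ["b"], ["c"]]))  # [0.5, 0.5, 0.5]; list length is preserved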
40 changes: 40 additions & 0 deletions tests/test_grpo.py
@@ -0,0 +1,40 @@
import unittest

from open_r1.grpo import GRPOScriptArguments


class TestGRPOScriptArguments(unittest.TestCase):
    def test_default_weights(self):
        """Test that default weights are correctly set when not provided."""
        args = GRPOScriptArguments(dataset_name="ABC")
        self.assertEqual(len(args.reward_funcs), len(args.reward_weights))
        self.assertEqual(args.reward_weights, [1.0] * len(args.reward_funcs))

    def test_custom_weights_valid(self):
        """Test that custom weights are accepted when matching reward_funcs length."""
        args = GRPOScriptArguments(
            dataset_name="ABC", reward_funcs=["accuracy", "format", "reasoning_steps"], reward_weights=[0.5, 1.0, 2.0]
        )
        self.assertEqual(args.reward_weights, [0.5, 1.0, 2.0])

    def test_custom_weights_invalid(self):
        """Test that mismatched weights raise ValueError."""
        with self.assertRaises(ValueError) as context:
            GRPOScriptArguments(
                dataset_name="ABC", reward_funcs=["accuracy", "format"], reward_weights=[1.0, 2.0, 3.0]
            )
        self.assertIn("Number of reward weights", str(context.exception))
        self.assertIn("must match number of reward functions", str(context.exception))

    def test_empty_weights_with_custom_funcs(self):
        """Test that omitted weights are filled with 1.0 for custom reward functions."""
        args = GRPOScriptArguments(
            dataset_name="ABC",
            reward_funcs=["accuracy", "format", "reasoning_steps"],
        )
        self.assertEqual(len(args.reward_weights), 3)
        self.assertEqual(args.reward_weights, [1.0, 1.0, 1.0])


if __name__ == "__main__":
    unittest.main()
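Since the new tests use only the standard library runner, they should be runnable with python -m unittest tests.test_grpo (assuming the repository layout puts open_r1 on the import path).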
30 changes: 30 additions & 0 deletions tests/test_rewards.py
@@ -2,6 +2,7 @@

from open_r1.rewards import (
    accuracy_reward,
    create_weighted_reward,
    format_reward,
    get_cosine_scaled_reward,
    get_repetition_penalty_reward,
@@ -75,6 +76,35 @@ def test_multiple_completions(self):
        self.assertEqual(rewards[0], 1.0)
        self.assertEqual(rewards[1], 0.0)

    def test_weighted_reward(self):
        """Test create_weighted_reward with different weights."""
        # Test with weight = 2.0
        completion = [[{"content": "<think>Some reasoning</think><answer>The answer</answer>"}]]
        base_reward_func = format_reward
        weighted_reward_func = create_weighted_reward(base_reward_func, 2.0)

        base_rewards = base_reward_func(completion)
        weighted_rewards = weighted_reward_func(completion)

        self.assertEqual(base_rewards[0], 1.0)
        self.assertEqual(weighted_rewards[0], 2.0)

        # Test with weight = 0.5
        weighted_reward_func = create_weighted_reward(base_reward_func, 0.5)
        weighted_rewards = weighted_reward_func(completion)
        self.assertEqual(weighted_rewards[0], 0.5)

        # Test with multiple completions
        completions = [
            [{"content": "<think>Some reasoning</think><answer>The answer</answer>"}],
            [{"content": "Invalid format"}],
        ]
        weighted_reward_func = create_weighted_reward(base_reward_func, 2.0)
        weighted_rewards = weighted_reward_func(completions)

        self.assertEqual(weighted_rewards[0], 2.0)
        self.assertEqual(weighted_rewards[1], 0.0)

    def test_cosine_scaled_reward(self):
        """Test cosine_scaled_reward with various cases."""
        # Test parameters
