Add E2B code interpreter reward function #364

Merged (38 commits) on Feb 19, 2025

Commits
- `e0decfd`: Add stuff (lewtun, Feb 13, 2025)
- `a290599`: Merge branch 'main' into grpo-code (lewtun, Feb 13, 2025)
- `7f4e8a3`: Merge branch 'main' into grpo-code (lewtun, Feb 13, 2025)
- `da19783`: Make it kind of work (lewtun, Feb 13, 2025)
- `6ba5302`: Add more stuff (lewtun, Feb 13, 2025)
- `f8e200e`: Merge branch 'main' into grpo-code (lewtun, Feb 13, 2025)
- `78cf722`: Add fix for parse (lewtun, Feb 13, 2025)
- `24dc34f`: Fix (lewtun, Feb 13, 2025)
- `22244fe`: Refactor (lewtun, Feb 13, 2025)
- `c32d137`: Clean up (lewtun, Feb 13, 2025)
- `dab15e0`: Fix config (lewtun, Feb 13, 2025)
- `edc502d`: Fix sys (lewtun, Feb 14, 2025)
- `27af68e`: Add SFT config (lewtun, Feb 15, 2025)
- `53eaddb`: Use min rate (lewtun, Feb 15, 2025)
- `385d799`: Fix eval (lewtun, Feb 16, 2025)
- `52fc681`: Add base model (lewtun, Feb 16, 2025)
- `884387f`: Add s1k (lewtun, Feb 16, 2025)
- `2d3c797`: Disable eval (lewtun, Feb 17, 2025)
- `f85b7b7`: Merge branch 'main' into grpo-code (lewtun, Feb 18, 2025)
- `aaa8f6f`: Fix (lewtun, Feb 18, 2025)
- `20a1ea0`: Add import checker (lewtun, Feb 18, 2025)
- `5863303`: Fix importer (lewtun, Feb 18, 2025)
- `8d78b8e`: Fix (lewtun, Feb 18, 2025)
- `932e69e`: Tune config (lewtun, Feb 18, 2025)
- `258406f`: Tune (lewtun, Feb 18, 2025)
- `fd9860e`: Fix (lewtun, Feb 18, 2025)
- `c614dbd`: Fix save (lewtun, Feb 18, 2025)
- `51815b2`: Tuen beta (lewtun, Feb 18, 2025)
- `21c9859`: Merge branch 'main' into grpo-code (lewtun, Feb 18, 2025)
- `da08407`: Remove configs (lewtun, Feb 18, 2025)
- `5f35a61`: Fix vLLM (lewtun, Feb 18, 2025)
- `93254b4`: Fix (lewtun, Feb 18, 2025)
- `853e42b`: Add note (lewtun, Feb 18, 2025)
- `23dfafd`: Add doc (lewtun, Feb 19, 2025)
- `65c44d8`: doc (lewtun, Feb 19, 2025)
- `04381ca`: Fix (lewtun, Feb 19, 2025)
- `fb6e4ae`: Tune lr (lewtun, Feb 19, 2025)
- `89ded43`: Add command (lewtun, Feb 19, 2025)
37 changes: 37 additions & 0 deletions README.md
@@ -170,6 +170,43 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con

Our final [model](https://huggingface.co/Dongwei/Qwen-2.5-7B_Base_Math_smalllr), while using different learning rates, loss functions, and reward structures, achieves 69.4% accuracy on MATH-500, demonstrating a more than 17% improvement over the base model.

#### 👨‍💻 Training with a code interpreter

We provide a `code` reward function for executing code generated by the policy during training. Currently, this reward function targets code contests like [Codeforces](https://codeforces.com), where solutions are executed against a set of test cases and the overall success rate is returned as the final reward. To ensure safe execution, we use [E2B](https://e2b.dev) sandboxes, which are fast and cheap to run. To use this reward function, first install the necessary dependencies:

```shell
uv pip install -e '.[code]'
```

Then create a `.env` file and place an API token from E2B within it:

```
E2B_API_KEY="e2b_xxx"
```

Then make sure your dataset contains a `verification_info` column with the following schema (adapted from PrimeIntellect's excellent [datasets](https://huggingface.co/collections/PrimeIntellect/synthetic-1-67a2c399cfdd6c9f7fae0c37) of verifiable problems):

```python
{
    "language": "python",
    "test_cases": [
        {
            "input": "4\n4\n0001\n1000\n0011\n0111\n3\n010\n101\n0\n2\n00000\n00001\n4\n01\n001\n0001\n00001\n",
            "output": "1\n3 \n-1\n0\n\n2\n1 2 \n",
            "type": "stdin_stdout",
        }
    ],
}
```
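A note on how these test cases travel: the `code` reward serializes them with `json.dumps` twice before templating them into the sandbox script (once to produce a JSON string, once to embed that string as a literal), so each entry must round-trip through JSON cleanly. A small sketch with a hypothetical row (only `verification_info` follows the documented schema; the `problem` field name is illustrative):

```python
import json

# Hypothetical dataset row; only `verification_info` follows the schema above.
row = {
    "problem": "Read an integer n and print n * 2.",
    "verification_info": {
        "language": "python",
        "test_cases": [
            {"input": "21\n", "output": "42\n", "type": "stdin_stdout"},
        ],
    },
}

# Double-encode as the reward function does, then round-trip back.
embedded = json.dumps(json.dumps(row["verification_info"]["test_cases"]))
recovered = json.loads(json.loads(embedded))
assert recovered == row["verification_info"]["test_cases"]
```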

For example, to train a smol model on Python problems, run:

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
    --num_processes=7 src/open_r1/grpo.py \
    --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
```

### Launching jobs on a Slurm cluster

If you have access to a Slurm cluster, we provide a `slurm/train.slurm` script that will automatically queue training jobs for you. Here's how you can use it:
57 changes: 57 additions & 0 deletions recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
@@ -0,0 +1,57 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2

# Data training arguments
dataset_name: open-r1/verifiable-coding-problems-python-10k
dataset_configs:
- default
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"

# GRPO trainer config
beta: 0.01
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: Qwen2.5-1.5B-Open-R1-Code-GRPO
hub_strategy: every_save
learning_rate: 5.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
  min_lr_rate: 0.1
max_prompt_length: 1024
max_completion_length: 2048
max_steps: 500
num_generations: 14
num_train_epochs: 1
output_dir: data/Qwen2.5-1.5B-Open-R1-Code-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 16
push_to_hub: true
report_to:
- wandb
reward_funcs:
- code
- format
reward_weights:
- 1.0
- 0.1
save_strategy: "steps"
save_steps: 50
save_total_limit: 1
seed: 42
temperature: 1.0
warmup_ratio: 0.03
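The config above weights the `code` reward at 1.0 and the `format` reward at 0.1. A toy sketch of how such weights combine per completion, assuming weighted-sum semantics for `reward_weights` (the reward values below are made up):

```python
# Per-completion rewards from each reward function (illustrative values):
code_rewards = [1.0, 0.5, 0.0]    # fraction of test cases passed
format_rewards = [1.0, 0.0, 1.0]  # 1.0 if the <think>/<answer> format matched
code_weight, format_weight = 1.0, 0.1  # reward_weights from the config

# Weighted sum yields one total reward per completion.
totals = [code_weight * c + format_weight * f
          for c, f in zip(code_rewards, format_rewards)]
assert totals == [1.1, 0.5, 0.1]
```

This is why a well-formatted completion that fails every test case still earns a small positive reward, keeping the format signal alive early in training.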
3 changes: 3 additions & 0 deletions setup.py
@@ -46,6 +46,7 @@
"datasets>=3.2.0",
"deepspeed==0.15.4",
"distilabel[vllm,ray,openai]>=1.5.2",
"e2b-code-interpreter>=1.0.5",
"einops>=0.8.0",
"flake8>=6.0.0",
"flash_attn>=2.7.4.post1",
@@ -60,6 +61,7 @@
"parameterized>=0.9.0",
"peft>=0.14.0",
"pytest",
"python-dotenv",
"ruff>=0.9.0",
"safetensors>=0.3.3",
"sentencepiece>=0.1.99",
@@ -88,6 +90,7 @@ def deps_list(*pkgs):
extras["torch"] = deps_list("torch")
extras["quality"] = deps_list("ruff", "isort", "flake8")
extras["train"] = deps_list("flash_attn")
extras["code"] = deps_list("e2b-code-interpreter", "python-dotenv")
extras["eval"] = deps_list("lighteval", "math-verify")
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"] + extras["train"]

2 changes: 2 additions & 0 deletions src/open_r1/grpo.py
@@ -27,6 +27,7 @@
from open_r1.configs import GRPOConfig
from open_r1.rewards import (
    accuracy_reward,
    code_reward,
    format_reward,
    get_cosine_scaled_reward,
    get_repetition_penalty_reward,
@@ -161,6 +162,7 @@ def main(script_args, training_args, model_args):
            max_penalty=script_args.repetition_max_penalty,
        ),
        "length": len_reward,
        "code": code_reward,
    }
    reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

87 changes: 87 additions & 0 deletions src/open_r1/rewards.py
@@ -1,12 +1,22 @@
"""Reward functions for GRPO training."""

import json
import math
import re
from typing import Dict

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify

from .utils import is_e2b_available


if is_e2b_available():
    from dotenv import load_dotenv
    from e2b_code_interpreter import Sandbox

    load_dotenv()


def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
Expand Down Expand Up @@ -271,3 +281,80 @@ def repetition_penalty_reward(completions, **kwargs) -> float:
return rewards

return repetition_penalty_reward


def extract_code(completion: str) -> str:
    pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
    matches = pattern.findall(completion)
    extracted_answer = matches[-1] if len(matches) >= 1 else ""
    return extracted_answer
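For reference, the extraction above keeps only the last fenced Python block in a completion, so a draft written inside `<think>` is ignored in favor of the final answer:

```python
import re

# Same pattern as extract_code: capture the contents of the LAST ```python fence.
pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)

completion = (
    "<think>\nA first draft:\n"
    "```python\nprint('draft')\n```\n"
    "That looks right.\n</think>\n<answer>\n"
    "```python\nprint('final')\n```\n</answer>"
)
matches = pattern.findall(completion)
extracted = matches[-1] if matches else ""
assert extracted == "print('final')\n"
```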


def code_reward(completions, **kwargs) -> list[float]:
    """Reward function that evaluates code snippets using the E2B code interpreter.

    Assumes the dataset contains a `verification_info` column with test cases.
    """
    if not is_e2b_available():
        raise ImportError(
            "E2B is not available and required for this reward function. Please install E2B with "
            "`pip install e2b-code-interpreter` and add an API key to a `.env` file."
        )

    rewards = []
    # TODO: add support for other languages in E2B: https://e2b.dev/docs/code-interpreting/supported-languages
    try:
        evaluation_script_template = """
> **Reviewer note:** It's surprising that you don't have any issue with the extra indentation.

> **Reviewer note:** I usually use `textwrap.dedent` in this case, but it might not be necessary here for some reason.
import subprocess
import json

def evaluate_code(code, test_cases):
    passed = 0
    total = len(test_cases)
    exec_timeout = 5
    for case in test_cases:
        try:
            process = subprocess.run(
                ["python3", "-c", code],
                input=case["input"],
                text=True,
                capture_output=True,
                timeout=exec_timeout,
            )
        except subprocess.TimeoutExpired:
            # A hung test case counts as failed instead of aborting the whole run
            continue
        if process.returncode != 0:  # Error in execution
            continue
        output = process.stdout.strip()
        if output == case["output"].strip():
            passed += 1
    success_rate = passed / total
    return success_rate

code_snippet = {code}
test_cases = json.loads({test_cases})
evaluate_code(code_snippet, test_cases)
"""
        code_snippets = [extract_code(completion[-1]["content"]) for completion in completions]
        verification_info = kwargs["verification_info"]
        scripts = [
            evaluation_script_template.format(
                code=json.dumps(code), test_cases=json.dumps(json.dumps(info["test_cases"]))
            )
            for code, info in zip(code_snippets, verification_info)
        ]
        with Sandbox(timeout=30, request_timeout=3) as sbx:
            # verification_info is a per-batch list, so pair each script with its own entry
            for script, info in zip(scripts, verification_info):
                execution = sbx.run_code(script, language=info["language"])
                try:
                    output = float(execution.text)
                except (TypeError, ValueError):
                    output = 0.0
                rewards.append(output)
    except Exception as e:
        print(f"Error from E2B executor: {e}")
        rewards = [0.0] * len(completions)
    return rewards
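To see the scoring semantics without an E2B account, here is a local, sandbox-free stand-in for the evaluation script above (same stripped-stdout comparison; `sys.executable` is used instead of the sandbox's `python3`):

```python
import subprocess
import sys

def success_rate(code: str, test_cases: list, exec_timeout: int = 5) -> float:
    """Run `code` once per test case, comparing stripped stdout to the expected output."""
    passed = 0
    for case in test_cases:
        try:
            proc = subprocess.run(
                [sys.executable, "-c", code],
                input=case["input"],
                text=True,
                capture_output=True,
                timeout=exec_timeout,
            )
        except subprocess.TimeoutExpired:
            continue  # a hung test case counts as failed
        if proc.returncode == 0 and proc.stdout.strip() == case["output"].strip():
            passed += 1
    return passed / len(test_cases)

cases = [
    {"input": "2\n", "output": "4\n", "type": "stdin_stdout"},
    {"input": "3\n", "output": "9\n", "type": "stdin_stdout"},
]
# This program doubles its input, so it passes the first case and fails the second.
rate = success_rate("n = int(input()); print(n * 2)", cases)
assert rate == 0.5
```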
3 changes: 2 additions & 1 deletion src/open_r1/utils/__init__.py
@@ -1,4 +1,5 @@
from .import_utils import is_e2b_available
from .model_utils import get_tokenizer


__all__ = ["get_tokenizer"]
__all__ = ["get_tokenizer", "is_e2b_available"]
23 changes: 23 additions & 0 deletions src/open_r1/utils/import_utils.py
@@ -0,0 +1,23 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers.utils.import_utils import _is_package_available


# Same availability-check pattern as transformers.utils.import_utils
_e2b_available = _is_package_available("e2b")


def is_e2b_available() -> bool:
return _e2b_available
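The same probe-once pattern can be sketched with only the standard library, in case you want an availability check without relying on `transformers`' private helper (the module names below are illustrative):

```python
from importlib.util import find_spec

# Probe once at import time, then expose cheap boolean accessors,
# mirroring the _is_package_available pattern above.
_json_available = find_spec("json") is not None
_fake_available = find_spec("hopefully_nonexistent_pkg_xyz") is not None

def is_json_available() -> bool:
    return _json_available

def is_fake_available() -> bool:
    return _fake_available

assert is_json_available()
assert not is_fake_available()
```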