diff --git a/README.md b/README.md
index a14f180f..1111231a 100644
--- a/README.md
+++ b/README.md
@@ -170,6 +170,43 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
Our final [model](https://huggingface.co/Dongwei/Qwen-2.5-7B_Base_Math_smalllr), while using different learning rates, loss functions and reward structures, achieves 69.4% accuracy on MATH-500, demonstrating a 17%+ improvement over the base model.
+
+#### 👨‍💻 Training with a code interpreter
+
+We provide a `code` reward function for executing code generated by the policy during training. Currently, this reward function targets code contests like [Codeforces](https://codeforces.com), where solutions are executed against a set of test cases and the overall success rate is returned as the final reward. To ensure safe execution, we use [E2B](https://e2b.dev) sandboxes, which are fast and cheap to run. To use this reward function, first install the necessary dependencies:
+
+```shell
+uv pip install -e '.[code]'
+```
+
+Then create a `.env` file and place an API token from E2B within it:
+
+```
+E2B_API_KEY="e2b_xxx"
+```
+
+Next, make sure your dataset contains a `verification_info` column with the following schema (adapted from PrimeIntellect's excellent [datasets](https://huggingface.co/collections/PrimeIntellect/synthetic-1-67a2c399cfdd6c9f7fae0c37) of verifiable problems):
+
+```python
+{
+ "language": "python",
+ "test_cases": [
+ {
+ "input": "4\n4\n0001\n1000\n0011\n0111\n3\n010\n101\n0\n2\n00000\n00001\n4\n01\n001\n0001\n00001\n",
+ "output": "1\n3 \n-1\n0\n\n2\n1 2 \n",
+ "type": "stdin_stdout",
+ }
+ ],
+}
+```
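+
+If your dataset does not have this column yet, you can usually attach it with a small `map`. The snippet below is only a hedged sketch: the source dataset name and its `tests` column are hypothetical, so adjust the names to whatever your data actually uses:
+
+```python
+from datasets import load_dataset
+
+# Hypothetical dataset with a `tests` column of {"input": ..., "output": ...} dicts
+dataset = load_dataset("your-org/your-coding-dataset", split="train")
+
+
+def add_verification_info(example):
+    example["verification_info"] = {
+        "language": "python",
+        "test_cases": [
+            {"input": t["input"], "output": t["output"], "type": "stdin_stdout"}
+            for t in example["tests"]
+        ],
+    }
+    return example
+
+
+dataset = dataset.map(add_verification_info)
+```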
+
+For example, to train a smol model on Python problems, run the following, where `--num_processes` is set to 7 so that one GPU of an 8-GPU node stays free for vLLM generation:
+
+```shell
+ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
+    --num_processes=7 src/open_r1/grpo.py \
+    --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
+```
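+
+You can also sanity-check the `code` reward outside of training by calling it directly on a toy completion. This is a minimal sketch, not part of the training pipeline; it assumes you have installed the `code` extra and set `E2B_API_KEY` in your `.env` file, and the hard-coded completion below is purely illustrative:
+
+```python
+from open_r1.rewards import code_reward
+
+# A fake completion whose last fenced Python block simply echoes stdin
+completion = "Here is my solution:\n```python\nprint(input())\n```"
+
+verification_info = {
+    "language": "python",
+    "test_cases": [{"input": "hello\n", "output": "hello\n", "type": "stdin_stdout"}],
+}
+
+rewards = code_reward(
+    [[{"role": "assistant", "content": completion}]],
+    verification_info=[verification_info],
+)
+print(rewards)  # [1.0] if the snippet passes the single test case
+```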
+
### Launching jobs on a Slurm cluster
If you have access to a Slurm cluster, we provide a `slurm/train.slurm` script that will automatically queue training jobs for you. Here's how you can use it:
diff --git a/recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml b/recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
new file mode 100644
index 00000000..783a4d2a
--- /dev/null
+++ b/recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
@@ -0,0 +1,57 @@
+# Model arguments
+model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
+model_revision: main
+torch_dtype: bfloat16
+attn_implementation: flash_attention_2
+
+# Data training arguments
+dataset_name: open-r1/verifiable-coding-problems-python-10k
+dataset_configs:
+- default
+system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
+
+# GRPO trainer config
+beta: 0.01
+bf16: true
+use_vllm: true
+vllm_device: auto
+vllm_gpu_memory_utilization: 0.9
+do_eval: false
+gradient_accumulation_steps: 4
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+hub_model_id: Qwen2.5-1.5B-Open-R1-Code-GRPO
+hub_strategy: every_save
+learning_rate: 5.0e-06
+log_completions: true
+log_level: info
+logging_first_step: true
+logging_steps: 1
+logging_strategy: steps
+lr_scheduler_type: cosine_with_min_lr
+lr_scheduler_kwargs:
+  min_lr_rate: 0.1
+max_prompt_length: 1024
+max_completion_length: 2048
+max_steps: 500
+num_generations: 14
+num_train_epochs: 1
+output_dir: data/Qwen2.5-1.5B-Open-R1-Code-GRPO
+overwrite_output_dir: true
+per_device_train_batch_size: 16
+push_to_hub: true
+report_to:
+- wandb
+reward_funcs:
+- code
+- format
+reward_weights:
+- 1.0
+- 0.1
+save_strategy: "steps"
+save_steps: 50
+save_total_limit: 1
+seed: 42
+temperature: 1.0
+warmup_ratio: 0.03
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 231de49a..907269c2 100644
--- a/setup.py
+++ b/setup.py
@@ -46,6 +46,7 @@
"datasets>=3.2.0",
"deepspeed==0.15.4",
"distilabel[vllm,ray,openai]>=1.5.2",
+ "e2b-code-interpreter>=1.0.5",
"einops>=0.8.0",
"flake8>=6.0.0",
"flash_attn>=2.7.4.post1",
@@ -60,6 +61,7 @@
"parameterized>=0.9.0",
"peft>=0.14.0",
"pytest",
+ "python-dotenv",
"ruff>=0.9.0",
"safetensors>=0.3.3",
"sentencepiece>=0.1.99",
@@ -88,6 +90,7 @@ def deps_list(*pkgs):
extras["torch"] = deps_list("torch")
extras["quality"] = deps_list("ruff", "isort", "flake8")
extras["train"] = deps_list("flash_attn")
+extras["code"] = deps_list("e2b-code-interpreter", "python-dotenv")
extras["eval"] = deps_list("lighteval", "math-verify")
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"] + extras["train"]
diff --git a/src/open_r1/grpo.py b/src/open_r1/grpo.py
index 7032346b..2ead27da 100644
--- a/src/open_r1/grpo.py
+++ b/src/open_r1/grpo.py
@@ -27,6 +27,7 @@
from open_r1.configs import GRPOConfig
from open_r1.rewards import (
accuracy_reward,
+ code_reward,
format_reward,
get_cosine_scaled_reward,
get_repetition_penalty_reward,
@@ -161,6 +162,7 @@ def main(script_args, training_args, model_args):
max_penalty=script_args.repetition_max_penalty,
),
"length": len_reward,
+ "code": code_reward,
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
diff --git a/src/open_r1/rewards.py b/src/open_r1/rewards.py
index 3deea193..c003f4d3 100644
--- a/src/open_r1/rewards.py
+++ b/src/open_r1/rewards.py
@@ -1,5 +1,6 @@
"""Reward functions for GRPO training."""
+import json
import math
import re
from typing import Dict
@@ -7,6 +8,15 @@
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
+from .utils import is_e2b_available
+
+
+if is_e2b_available():
+ from dotenv import load_dotenv
+ from e2b_code_interpreter import Sandbox
+
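+    # Load E2B_API_KEY (and any other secrets) from a local .env file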
+ load_dotenv()
+
def accuracy_reward(completions, solution, **kwargs):
"""Reward function that checks if the completion is the same as the ground truth."""
@@ -271,3 +281,80 @@ def repetition_penalty_reward(completions, **kwargs) -> float:
return rewards
return repetition_penalty_reward
+
+
+def extract_code(completion: str) -> str:
+    """Extract the contents of the last fenced Python code block in a completion."""
+    pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
+    matches = pattern.findall(completion)
+    extracted_answer = matches[-1] if matches else ""
+    return extracted_answer
+
+
+def code_reward(completions, **kwargs) -> list[float]:
+ """Reward function that evaluates code snippets using the E2B code interpreter.
+
+ Assumes the dataset contains a `verification_info` column with test cases.
+ """
+ if not is_e2b_available():
+ raise ImportError(
+ "E2B is not available and required for this reward function. Please install E2B with "
+ "`pip install e2b-code-interpreter` and add an API key to a `.env` file."
+ )
+
+ rewards = []
+ # TODO: add support for other languages in E2B: https://e2b.dev/docs/code-interpreting/supported-languages
+ try:
+ evaluation_script_template = """
+ import subprocess
+ import json
+
+ def evaluate_code(code, test_cases):
+ passed = 0
+ total = len(test_cases)
+ exec_timeout = 5
+
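+        # Run the candidate program once per test case, feeding `input` on stdin
+        # and comparing its stripped stdout against the expected `output`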
+ for case in test_cases:
+ process = subprocess.run(
+ ["python3", "-c", code],
+ input=case["input"],
+ text=True,
+ capture_output=True,
+ timeout=exec_timeout
+ )
+
+ if process.returncode != 0: # Error in execution
+ continue
+
+ output = process.stdout.strip()
+ if output.strip() == case["output"].strip():
+ passed += 1
+
+ success_rate = (passed / total)
+ return success_rate
+
+ code_snippet = {code}
+ test_cases = json.loads({test_cases})
+
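+    # The bare call below is the last expression in the cell, so the sandbox
+    # returns its value (the success rate) via `execution.text`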
+ evaluate_code(code_snippet, test_cases)
+ """
+ code_snippets = [extract_code(completion[-1]["content"]) for completion in completions]
+ verification_info = kwargs["verification_info"]
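+        # `code` is substituted as a Python string literal via json.dumps; `test_cases` is
+        # double-encoded so the substituted value is a quoted JSON string that the sandboxed
+        # `json.loads` call can parse back into a list of test cases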
+ scripts = [
+ evaluation_script_template.format(
+ code=json.dumps(code), test_cases=json.dumps(json.dumps(info["test_cases"]))
+ )
+ for code, info in zip(code_snippets, verification_info)
+ ]
+ with Sandbox(timeout=30, request_timeout=3) as sbx:
+            # `verification_info` is a list with one entry per completion, so pair each
+            # script with its own metadata rather than indexing the list by key
+            for script, info in zip(scripts, verification_info):
+                execution = sbx.run_code(script, language=info["language"])
+ try:
+ output = float(execution.text)
+ except (TypeError, ValueError):
+ output = 0.0
+ rewards.append(output)
+ except Exception as e:
+ print(f"Error from E2B executor: {e}")
+ rewards = [0.0] * len(completions)
+ return rewards
diff --git a/src/open_r1/utils/__init__.py b/src/open_r1/utils/__init__.py
index b1de213d..5302463e 100644
--- a/src/open_r1/utils/__init__.py
+++ b/src/open_r1/utils/__init__.py
@@ -1,4 +1,5 @@
+from .import_utils import is_e2b_available
from .model_utils import get_tokenizer
-__all__ = ["get_tokenizer"]
+__all__ = ["get_tokenizer", "is_e2b_available"]
diff --git a/src/open_r1/utils/import_utils.py b/src/open_r1/utils/import_utils.py
new file mode 100644
index 00000000..8893264a
--- /dev/null
+++ b/src/open_r1/utils/import_utils.py
@@ -0,0 +1,23 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers.utils.import_utils import _is_package_available
+
+
+# Same approach as transformers.utils.import_utils
+_e2b_available = _is_package_available("e2b")
+
+
+def is_e2b_available() -> bool:
+ return _e2b_available