diff --git a/Makefile b/Makefile
index 1140e59e..17e2ddc1 100644
--- a/Makefile
+++ b/Makefile
@@ -26,7 +26,7 @@ evaluate:
 		fi \
 	),))
 	$(if $(filter tensor,$(PARALLEL)),export VLLM_WORKER_MULTIPROC_METHOD=spawn &&,) \
-	MODEL_ARGS="pretrained=$(MODEL),dtype=float16,$(PARALLEL_ARGS),max_model_length=32768,gpu_memory_utilisation=0.8" && \
+	MODEL_ARGS="pretrained=$(MODEL),dtype=bfloat16,$(PARALLEL_ARGS),max_model_length=32768,gpu_memory_utilisation=0.8" && \
 	lighteval vllm $$MODEL_ARGS "custom|$(TASK)|0|0" \
 	--custom-tasks src/open_r1/evaluate.py \
 	--use-chat-template \
diff --git a/README.md b/README.md
index f02f8e09..0faffa2e 100644
--- a/README.md
+++ b/README.md
@@ -50,23 +50,23 @@ To install `uv`, follow the [UV Installation Guide](https://docs.astral.sh/uv/ge
 
 ```shell
-uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --upgrade pip
+uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --upgrade pip --link-mode=copy
 ```
 
 Next, install vLLM:
 
 ```shell
-uv pip install vllm>=0.7.0
+uv pip install vllm==0.7.1
 
 # For CUDA 12.1
-pip install vllm>=0.7.0 --extra-index-url https://download.pytorch.org/whl/cu121
+uv pip install vllm==0.7.1 --extra-index-url https://download.pytorch.org/whl/cu121 --index-strategy unsafe-best-match --link-mode=copy
 export LD_LIBRARY_PATH=$(python -c "import site; print(site.getsitepackages()[0] + '/nvidia/nvjitlink/lib')"):$LD_LIBRARY_PATH
 ```
 
 This will also install PyTorch `v2.5.1` and it is **very important** to use this version since the vLLM binaries are compiled for it. You can then install the remaining dependencies for your specific use case via `pip install -e .[LIST OF MODES]`. For most contributors, we recommend:
 
 ```shell
-pip install -e ".[dev]"
+GIT_LFS_SKIP_SMUDGE=1 uv pip install -e ".[dev]" --link-mode=copy
 ```
 
 Next, log into your Hugging Face and Weights and Biases accounts as follows:
 
@@ -141,30 +141,46 @@ We use `lighteval` to evaluate models, with custom tasks defined in `src/open_r1
 
 ```shell
 MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
-MODEL_ARGS="pretrained=$MODEL,dtype=float16,max_model_length=32768,gpu_memory_utilisation=0.8"
-TASK=aime24
+MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilisation=0.8"
 OUTPUT_DIR=data/evals/$MODEL
 
+# AIME 2024
+TASK=aime24
+lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
+    --custom-tasks src/open_r1/evaluate.py \
+    --use-chat-template \
+    --output-dir $OUTPUT_DIR
+
+# MATH-500
+TASK=math_500
+lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
+    --custom-tasks src/open_r1/evaluate.py \
+    --use-chat-template \
+    --output-dir $OUTPUT_DIR
+
+# GPQA Diamond
+TASK=gpqa:diamond
 lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
     --custom-tasks src/open_r1/evaluate.py \
     --use-chat-template \
-    --system-prompt="Please reason step by step, and put your final answer within \boxed{}." \
     --output-dir $OUTPUT_DIR
 ```
 
+> [!IMPORTANT]
+> You must set `max_model_length=32768` in the `vllm` command to align with the `generation_size` we define per eval. Without this, `lighteval` will throw an error.
+
 To increase throughput across multiple GPUs, use _data parallel_ as follows:
 
 ```shell
 NUM_GPUS=8
 MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
-MODEL_ARGS="pretrained=$MODEL,dtype=float16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8"
+MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8"
 TASK=aime24
 OUTPUT_DIR=data/evals/$MODEL
 
 lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
     --custom-tasks src/open_r1/evaluate.py \
     --use-chat-template \
-    --system-prompt="Please reason step by step, and put your final answer within \boxed{}." \
     --output-dir $OUTPUT_DIR
 ```
 
@@ -173,7 +189,7 @@ For large models which require sharding across GPUs, use _tensor parallel_ and r
 ```shell
 NUM_GPUS=8
 MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B
-MODEL_ARGS="pretrained=$MODEL,dtype=float16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8"
+MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8"
 TASK=aime24
 OUTPUT_DIR=data/evals/$MODEL
 
@@ -181,50 +197,97 @@ export VLLM_WORKER_MULTIPROC_METHOD=spawn
 lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
     --custom-tasks src/open_r1/evaluate.py \
     --use-chat-template \
-    --system-prompt="Please reason step by step, and put your final answer within \boxed{}." \
     --output-dir $OUTPUT_DIR
 ```
 
 You can also launch an evaluation with `make evaluate`, specifying the model, task, and optionally the parallelism technique and number of GPUs.
 
 To evaluate on a single GPU:
+
 ```shell
 make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24
 ```
 
 To use Data Parallelism:
+
 ```shell
 make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLEL=data NUM_GPUS=8
 ```
 
 To use Tensor Parallelism:
+
 ```shell
 make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLEL=tensor NUM_GPUS=8
 ```
 
-## Reproducing Deepseek's evaluation results on MATH-500
-We are able to reproduce Deepseek's reported results on the MATH-500 Benchmark:
-| Model | MATH-500 (HF lighteval) | MATH-500 (DeepSeek Reported) |
-| :-------------------------- | :-------: | :----------------------------: |
-| DeepSeek-R1-Distill-Qwen-1.5B | 81.6 | 83.9 |
-| DeepSeek-R1-Distill-Qwen-7B | 91.8 | 92.8 |
-| DeepSeek-R1-Distill-Qwen-14B | 94.2 | 93.9 |
-| DeepSeek-R1-Distill-Qwen-32B | 95.0 | 94.3 |
-| DeepSeek-R1-Distill-Llama-8B | 85.8 | 89.1 |
-| DeepSeek-R1-Distill-Llama-70B | 93.4 | 94.5 |
+## Reproducing DeepSeek's evaluation results
+
+> [!NOTE]
+> The DeepSeek-R1 paper uses sampling with a temperature of 0.6, a top-p value of 0.95, and 64 responses per query to estimate `pass@1`. Below, we report the results from greedy decoding, which likely explains the small 1-3σ discrepancies between our results and theirs.
+
+### MATH-500
+
+We are able to reproduce DeepSeek's reported results on the MATH-500 benchmark within ~1-3 standard deviations:
+
+| Model                         | MATH-500 (🤗 LightEval) | MATH-500 (DeepSeek Reported) |
+|:------------------------------|:-----------------------:|:----------------------------:|
+| DeepSeek-R1-Distill-Qwen-1.5B | 81.2                    | 83.9                         |
+| DeepSeek-R1-Distill-Qwen-7B   | 91.8                    | 92.8                         |
+| DeepSeek-R1-Distill-Qwen-14B  | 94.2                    | 93.9                         |
+| DeepSeek-R1-Distill-Qwen-32B  | 95.0                    | 94.3                         |
+| DeepSeek-R1-Distill-Llama-8B  | 85.4                    | 89.1                         |
+| DeepSeek-R1-Distill-Llama-70B | 93.4                    | 94.5                         |
 
 To reproduce these results use the following command:
+
+```shell
+NUM_GPUS=1 # Set to 8 for 32B and 70B models
+MODEL=deepseek-ai/{model_name}
+MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilisation=0.8,tensor_parallel_size=$NUM_GPUS"
+OUTPUT_DIR=data/evals/$MODEL
+
+lighteval vllm $MODEL_ARGS "custom|math_500|0|0" \
+    --custom-tasks src/open_r1/evaluate.py \
+    --use-chat-template \
+    --output-dir $OUTPUT_DIR
+```
+
+Alternatively, you can launch Slurm jobs as follows:
+
 ```shell
-sbatch slurm/evaluate.slurm deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B math_500
-sbatch slurm/evaluate.slurm deepseek-ai/DeepSeek-R1-Distill-Qwen-7B math_500
-sbatch slurm/evaluate.slurm deepseek-ai/DeepSeek-R1-Distill-Qwen-14B math_500
-sbatch slurm/evaluate.slurm deepseek-ai/DeepSeek-R1-Distill-Qwen-32B math_500 tp
-sbatch slurm/evaluate.slurm deepseek-ai/DeepSeek-R1-Distill-Llama-8B math_500
-sbatch slurm/evaluate.slurm deepseek-ai/DeepSeek-R1-Distill-Llama-70B math_500 tp
+python scripts/run_benchmarks.py --model-id={model_id} --benchmarks math_500
 ```
 
+### GPQA Diamond
+
+We are able to reproduce DeepSeek's reported results on the GPQA Diamond benchmark within ~1-3 standard deviations:
+
+| Model                         | GPQA Diamond (🤗 LightEval) | GPQA Diamond (DeepSeek Reported) |
+|:------------------------------|:---------------------------:|:--------------------------------:|
+| DeepSeek-R1-Distill-Qwen-1.5B | 33.3                        | 33.8                             |
+| DeepSeek-R1-Distill-Qwen-7B   | 48.4                        | 49.1                             |
+| DeepSeek-R1-Distill-Qwen-14B  | 55.6                        | 59.1                             |
+| DeepSeek-R1-Distill-Qwen-32B  | 58.6                        | 62.1                             |
+| DeepSeek-R1-Distill-Llama-8B  | 51.0                        | 49.0                             |
+| DeepSeek-R1-Distill-Llama-70B | 65.2                        | 65.2                             |
+
+To reproduce these results use the following command:
+
+```shell
+NUM_GPUS=1 # Set to 8 for 32B and 70B models
+MODEL=deepseek-ai/{model_name}
+MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilisation=0.8,tensor_parallel_size=$NUM_GPUS"
+OUTPUT_DIR=data/evals/$MODEL
+
+lighteval vllm $MODEL_ARGS "custom|gpqa:diamond|0|0" \
+    --custom-tasks src/open_r1/evaluate.py \
+    --use-chat-template \
+    --output-dir $OUTPUT_DIR
+```
+
+Alternatively, you can launch Slurm jobs as follows:
+
+```shell
+python scripts/run_benchmarks.py --model-id={model_id} --benchmarks gpqa
+```
 
 ## Data generation
diff --git a/logs/.gitkeep b/logs/.gitkeep
new file mode 100644
index 00000000..e69de29b
diff --git a/setup.py b/setup.py
index 7ac01918..a2dd93ad 100644
--- a/setup.py
+++ b/setup.py
@@ -53,17 +53,17 @@
     "huggingface-hub[cli]>=0.19.2,<1.0",
     "isort>=5.12.0",
     "liger_kernel==0.5.2",
-    "lighteval @ git+https://github.com/huggingface/lighteval.git@0e462692436e1f0575bdb4c6ef63453ad9bde7d4#egg=lighteval[math]",
-    "math-verify>=0.3.3",  # Used for math verification in grpo
+    "lighteval @ git+https://github.com/huggingface/lighteval.git@86f62259f105ae164f655e0b91c92a823a742724#egg=lighteval[math]",
+    "math-verify==0.5.2",  # Used for math verification in grpo
     "packaging>=23.0",
     "parameterized>=0.9.0",
     "pytest",
     "safetensors>=0.3.3",
     "sentencepiece>=0.1.99",
-    "torch>=2.5.1",
+    "torch==2.5.1",
     "transformers @ git+https://github.com/huggingface/transformers.git@main",
     "trl @ git+https://github.com/huggingface/trl.git@main",
-    "vllm>=0.7.1",
+    "vllm==0.7.1",
     "wandb>=0.19.1",
 ]
 
diff --git a/slurm/eval_callback.slurm b/slurm/eval_callback.slurm
deleted file mode 100644
index bec49ab9..00000000
--- a/slurm/eval_callback.slurm
+++ /dev/null
@@ -1,75 +0,0 @@
-#!/bin/bash
-#SBATCH --ntasks-per-node=1
-#SBATCH --gres=gpu:8
-#SBATCH --partition=hopper-prod
-#SBATCH --output=./logs/evaluate/%x-%j.out
-#SBATCH --err=./logs/evaluate/%x-%j.err
-#SBATCH --requeue
-
-set -x -e
-source ~/.bashrc
-source openr1/bin/activate
-TASK_NAME=$1
-TASKS=$2
-MODEL_ID=$3
-MODEL_REVISION=$4
-# Optional args
-[ -z "$5"] && TENSOR_PARALLEL=False || TENSOR_PARALLEL=$5
-[ -z "$6"] && TRUST_REMOTE_CODE=False || TRUST_REMOTE_CODE=$6
-# $7 is reserved for system_prompt, see line 51
-NUM_GPUS=$(nvidia-smi -L | wc -l)
-
-# Set Whether to use tensor parallelism or data parallelism
-if [ "$TENSOR_PARALLEL" = "True" ]; then
-    # use TP to shard model across NUM_GPUS
-    export VLLM_WORKER_MULTIPROC_METHOD=spawn
-    MODEL_ARGS="pretrained=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8"
-else
-    MODEL_ARGS="pretrained=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8"
-fi
-
-LM_EVAL_REPO_ID="open-r1/open-r1-eval-leaderboard"
-MODEL_NAME=$(echo $MODEL_ID | sed 's/\//_/g') # replaces / with _
-DETAILS_REPO_ID="open-r1/details-$MODEL_NAME"
-OUTPUT_DIR="eval_results/$MODEL_ID/$MODEL_REVISION/$TASK_NAME"
-# We need this flag since we run this script from training jobs that use DeepSpeed and the env vars get progated which causes errors during evaluation
-ACCELERATE_USE_DEEPSPEED=false
-# Enable fast downloads
-HF_HUB_ENABLE_HF_TRANSFER=1
-
-echo "Running lighteval script ..."
-echo "Eval results will be saved to $OUTPUT_DIR"
-# Check if "custom" is a substring of TASKS
-if [[ $TASKS == *"custom"* ]]; then
-    echo "Custom task detected. Running custom task evaluation script ..."
-    lighteval vllm $MODEL_ARGS $TASKS \
-        --custom-tasks "src/open_r1/evaluate.py" \
-        --use-chat-template \
-        --output-dir $OUTPUT_DIR \
-        --save-details \
-        ${7:+--system-prompt "$7"}
-else
-    lighteval vllm $MODEL_ARGS $TASKS \
-        --use-chat-template \
-        --output-dir $OUTPUT_DIR \
-        --save-details \
-        ${7:+--system-prompt "$7"}
-fi
-
-OUTPUT_FILEPATHS=$(find $OUTPUT_DIR/results/ -type f \( -name "*.json" \))
-for filepath in $OUTPUT_FILEPATHS; do
-    echo "Uploading $filepath to Hugging Face Hub..."
-    filename=$(basename -- "$filepath")
-    huggingface-cli upload --repo-type space --private $LM_EVAL_REPO_ID $filepath $OUTPUT_DIR/$filename
-done
-
-echo "Uploading details to Hugging Face Hub..."
-DETAILS_FILEPATHS=$(find $OUTPUT_DIR/details/ -type f \( -name "*.parquet" \))
-echo "DETAILS_FILEPATHS: $DETAILS_FILEPATHS"
-TIMESTAMP=$(date +"%Y-%m-%dT%H-%M-%S")
-python src/open_r1/utils/upload_details.py --data_files $DETAILS_FILEPATHS --hub_repo_id $DETAILS_REPO_ID --config_name $MODEL_REVISION.$TASK_NAME.$TIMESTAMP
-
-echo "Cleaning up ..."
-rm -rf $OUTPUT_DIR
-
-echo "Done!"
\ No newline at end of file
diff --git a/slurm/evaluate.slurm b/slurm/evaluate.slurm
index 0ca4a870..47659807 100644
--- a/slurm/evaluate.slurm
+++ b/slurm/evaluate.slurm
@@ -1,55 +1,75 @@
 #!/bin/bash
-#SBATCH --job-name=open-r1-evaluate
-#SBATCH --nodes=1
 #SBATCH --ntasks-per-node=1
-#SBATCH --exclusive
 #SBATCH --gres=gpu:8
-#SBATCH --partition=hopper-prod
-#SBATCH --time=01:59:00
-#SBATCH --output=./logs/evaluate/%x-%j.out
-#SBATCH --err=./logs/evaluate/%x-%j.err
-
-# Usage: sbatch slurm/evaluate.slurm deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B aime24
+#SBATCH --partition=hopper-prod
+#SBATCH --output=./logs/%x-%j.out
+#SBATCH --err=./logs/%x-%j.err
+#SBATCH --requeue
 
 set -x -e
-
 source ~/.bashrc
 source openr1/bin/activate
-module load cuda/12.1
-echo "START TIME: $(date)"
-echo "PYTHON ENV: $(which python)"
-
+TASK_NAME=$1
+TASKS=$2
+MODEL_ID=$3
+MODEL_REVISION=$4
+# Optional args
+[ -z "$5" ] && TENSOR_PARALLEL=False || TENSOR_PARALLEL=$5
+[ -z "$6" ] && TRUST_REMOTE_CODE=False || TRUST_REMOTE_CODE=$6
+# $7 is reserved for system_prompt, see line 51
+NUM_GPUS=$(nvidia-smi -L | wc -l)
 
-NUM_GPUS=8
-MODEL=$1
-TASK=$2
-# Check if a third argument is passed, if it is tp then eval with tensor parallelism. Required for larger models
-if [ -n "$3" ] && [ "$3" == "tp" ]; then
-    MODEL_ARGS="pretrained=$MODEL,dtype=float16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8"
+# Set whether to use tensor parallelism or data parallelism
+if [ "$TENSOR_PARALLEL" = "True" ]; then
+    # use TP to shard model across NUM_GPUS
+    export VLLM_WORKER_MULTIPROC_METHOD=spawn
+    MODEL_ARGS="pretrained=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8"
 else
-    MODEL_ARGS="pretrained=$MODEL,dtype=float16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8"
+    MODEL_ARGS="pretrained=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8"
 fi
 
-OUTPUT_DIR=data/evals/$MODEL
-
-
-# force crashing on nccl issues like hanging broadcast
-export NCCL_ASYNC_ERROR_HANDLING=1
-# export NCCL_DEBUG=INFO
-# export NCCL_DEBUG_SUBSYS=COLL
-# export NCCL_SOCKET_NTHREADS=1
-# export NCCL_NSOCKS_PERTHREAD=1
-# export CUDA_LAUNCH_BLOCKING=1
-# Specific configuration optimized for the Hugging Face Compute Cluster
-# Be ye warned this may not work on other clusters!
-module load cuda/12.1
+LM_EVAL_REPO_ID="open-r1/open-r1-eval-leaderboard"
+MODEL_NAME=$(echo $MODEL_ID | sed 's/\//_/g') # replaces / with _
+DETAILS_REPO_ID="open-r1/details-$MODEL_NAME"
+OUTPUT_DIR="eval_results/$MODEL_ID/$MODEL_REVISION/$TASK_NAME"
+# We need this flag since we run this script from training jobs that use DeepSpeed and the env vars get propagated which causes errors during evaluation
+ACCELERATE_USE_DEEPSPEED=false
+# Enable fast downloads
+HF_HUB_ENABLE_HF_TRANSFER=1
 
-lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
-    --custom-tasks src/open_r1/evaluate.py \
+echo "Running lighteval script ..."
+echo "Eval results will be saved to $OUTPUT_DIR"
+# Check if "custom" is a substring of TASKS
+if [[ $TASKS == *"custom"* ]]; then
+    echo "Custom task detected. Running custom task evaluation script ..."
+    lighteval vllm $MODEL_ARGS $TASKS \
+        --custom-tasks "src/open_r1/evaluate.py" \
     --use-chat-template \
-    --system-prompt="Please reason step by step, and put your final answer within \boxed{}." \
+        --output-dir $OUTPUT_DIR \
     --save-details \
-    --output-dir $OUTPUT_DIR
+        ${7:+--system-prompt "$7"}
+else
+    lighteval vllm $MODEL_ARGS $TASKS \
+        --use-chat-template \
+        --output-dir $OUTPUT_DIR \
+        --save-details \
+        ${7:+--system-prompt "$7"}
+fi
+
+OUTPUT_FILEPATHS=$(find $OUTPUT_DIR/results/ -type f \( -name "*.json" \))
+for filepath in $OUTPUT_FILEPATHS; do
+    echo "Uploading $filepath to Hugging Face Hub..."
+    filename=$(basename -- "$filepath")
+    huggingface-cli upload --repo-type space --private $LM_EVAL_REPO_ID $filepath $OUTPUT_DIR/$filename
+done
+
+echo "Uploading details to Hugging Face Hub..."
+DETAILS_FILEPATHS=$(find $OUTPUT_DIR/details/ -type f \( -name "*.parquet" \))
+echo "DETAILS_FILEPATHS: $DETAILS_FILEPATHS"
+TIMESTAMP=$(date +"%Y-%m-%dT%H-%M-%S")
+python src/open_r1/utils/upload_details.py --data_files $DETAILS_FILEPATHS --hub_repo_id $DETAILS_REPO_ID --config_name $MODEL_REVISION.$TASK_NAME.$TIMESTAMP
+
+echo "Cleaning up ..."
+rm -rf $OUTPUT_DIR
 
-echo "END TIME: $(date)"
+echo "Done!"
\ No newline at end of file
diff --git a/src/open_r1/evaluate.py b/src/open_r1/evaluate.py
index c800a889..0447b266 100644
--- a/src/open_r1/evaluate.py
+++ b/src/open_r1/evaluate.py
@@ -14,8 +14,11 @@
 """Custom evaluation tasks for LightEval."""
 
+import random
+
 from lighteval.metrics.dynamic_metrics import (
     ExprExtractionConfig,
+    IndicesExtractionConfig,
     LatexExtractionConfig,
     multilingual_extractive_match_metric,
 )
 
@@ -44,6 +47,13 @@
     aggregation_function=max,
 )
 
+gpqa_metric = multilingual_extractive_match_metric(
+    language=Language.ENGLISH,
+    gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
+    pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
+    precision=5,
+)
+
 
 def prompt_fn(line, task_name: str = None):
     """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
@@ -64,6 +74,23 @@ def aime_prompt_fn(line, task_name: str = None):
     )
 
 
+def gpqa_prompt_fn(line, task_name: str = None):
+    """Prompt template adapted from simple-evals: https://github.com/openai/simple-evals/blob/83ed7640a7d9cd26849bcb3340125002ef14abbe/common.py#L14"""
+    gold_index = random.randint(0, 3)
+    choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]]
+    choices.insert(gold_index, line["Correct Answer"])
+    query_template = "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}"
+    query = query_template.format(A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=line["Question"])
+
+    return Doc(
+        task_name=task_name,
+        query=query,
+        choices=["A", "B", "C", "D"],
+        gold_index=gold_index,
+        instruction=query,
+    )
+
+
 # Define tasks
 aime24 = LightevalTaskConfig(
     name="aime24",
@@ -93,11 +120,29 @@ def aime_prompt_fn(line, task_name: str = None):
     metric=[latex_gold_metric],
     version=1,
 )
+gpqa_diamond = LightevalTaskConfig(
+    name="gpqa:diamond",
+    suite=["custom"],
+    prompt_function=gpqa_prompt_fn,
+    hf_repo="Idavidrein/gpqa",
+    hf_subset="gpqa_diamond",
+    hf_avail_splits=["train"],
+    evaluation_splits=["train"],
+    few_shots_split=None,
+    few_shots_select=None,
+    generation_size=32768,  # needed for reasoning models like R1
+    metric=[gpqa_metric],
+    stop_sequence=[],  # no stop sequence, will use eos token
+    trust_dataset=True,
+    version=1,
+)
+
 
 # Add tasks to the table
 TASKS_TABLE = []
 TASKS_TABLE.append(aime24)
 TASKS_TABLE.append(math_500)
+TASKS_TABLE.append(gpqa_diamond)
 
 # MODULE LOGIC
 if __name__ == "__main__":
diff --git a/src/open_r1/grpo.py b/src/open_r1/grpo.py
index e8c1c556..4bdc335f 100644
--- a/src/open_r1/grpo.py
+++ b/src/open_r1/grpo.py
@@ -68,7 +68,7 @@ def accuracy_reward(completions, solution, **kwargs):
                         malformed_operators=False,
                         basic_latex=True,
                         equations=True,
-                        boxed=True,
+                        boxed="all",
                         units=True,
                     ),
                     # Ensures that boxed is tried first
diff --git a/src/open_r1/utils/evaluation.py b/src/open_r1/utils/evaluation.py
index 9cbac82d..86de906d 100644
--- a/src/open_r1/utils/evaluation.py
+++ b/src/open_r1/utils/evaluation.py
@@ -48,6 +48,7 @@ def register_lighteval_task(
 
 register_lighteval_task(LIGHTEVAL_TASKS, "custom", "math_500", "math_500", 0)
 register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime24", "aime24", 0)
+register_lighteval_task(LIGHTEVAL_TASKS, "custom", "gpqa", "gpqa:diamond", 0)
 
 
 def get_lighteval_tasks():
@@ -74,7 +75,7 @@ def run_lighteval_job(
     cmd_args = [
         f"--gres=gpu:{num_gpus}",
         f"--job-name=or1_{benchmark}_{model_name.split('/')[-1]}_{model_revision}",
-        "slurm/eval_callback.slurm",
+        "slurm/evaluate.slurm",
         benchmark,
         f'"{task_list}"',
         model_name,
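
The choice shuffling introduced in `gpqa_prompt_fn` can be sanity-checked outside of LightEval. The snippet below is a minimal standalone sketch, not part of the patch itself: it rebuilds the prompt from a made-up GPQA-style record (the question and answer strings are invented for illustration) and asserts that the gold letter always labels the correct answer.

```python
import random

# Made-up GPQA-style record for illustration; real rows come from the Idavidrein/gpqa dataset.
line = {
    "Question": "Which particle mediates the electromagnetic force?",
    "Correct Answer": "The photon",
    "Incorrect Answer 1": "The gluon",
    "Incorrect Answer 2": "The W boson",
    "Incorrect Answer 3": "The graviton",
}

# Same template string used by gpqa_prompt_fn in src/open_r1/evaluate.py.
QUERY_TEMPLATE = (
    "Answer the following multiple choice question. The last line of your response "
    "should be of the following format: 'Answer: $LETTER' (without quotes) where "
    "LETTER is one of ABCD. Think step by step before answering.\n\n"
    "{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}"
)

for _ in range(1_000):
    # Mirror the prompt function: drop the correct answer into a random slot among the distractors.
    gold_index = random.randint(0, 3)
    choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]]
    choices.insert(gold_index, line["Correct Answer"])
    query = QUERY_TEMPLATE.format(
        A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=line["Question"]
    )

    # The letter reported as gold must always point at the correct answer in the rendered prompt.
    assert choices[gold_index] == line["Correct Answer"]
    assert f"{'ABCD'[gold_index]}) {line['Correct Answer']}" in query
```

Randomising the gold position per question keeps the correct answer from always landing on the same letter, which avoids rewarding a position-biased model.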