7 changes: 7 additions & 0 deletions examples/asymre_gsm8k/README.md
@@ -0,0 +1,7 @@
# Example: AsymRE on GSM8k dataset

This example demonstrates how to use [AsymRE](https://arxiv.org/abs/2506.20520) on the [GSM8k dataset](https://huggingface.co/datasets/openai/gsm8k).

For more details, see the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md).

The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml).
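
Assuming a working Trinity-RFT installation, the example can typically be launched with the repository's runner, e.g. `trinity run --config examples/asymre_gsm8k/gsm8k.yaml`, after filling in the `/PATH/TO/...` placeholders.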
66 changes: 66 additions & 0 deletions examples/asymre_gsm8k/gsm8k.yaml
@@ -0,0 +1,66 @@
project: sync_offset_0_sync20
name: asymre-gsm8k_shift-0.1
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
model:
model_path: /PATH/TO/MODEL/
max_response_tokens: 1024
max_model_len: 1280
algorithm:
algorithm_type: asymre
policy_loss_fn_args:
tau: 0
advantage_fn_args:
baseline_shift: -0.1
repeat_times: 8
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_steps: 80
batch_size: 96
max_retry_times: 3
max_retry_interval: 1
explorer_input:
taskset:
name: gsm8k
storage_type: file
path: /PATH/TO/DATASET/
split: train
format:
prompt_key: question
response_key: answer
rollout_args:
temperature: 1.0
eval_tasksets:
- name: gsm8k-eval
storage_type: file
path: /PATH/TO/DATASET/
split: test
format:
prompt_key: question
response_key: answer
default_workflow_type: math_workflow
trainer_input:
experience_buffer:
name: gsm8k_buffer
storage_type: queue
explorer:
eval_interval: 20
runner_num: 64
rollout_model:
engine_type: vllm_async
engine_num: 4
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
synchronizer:
sync_method: nccl
sync_interval: 20
sync_timeout: 1200
sync_offset: 0
trainer:
trainer_type: verl
trainer_config_path: examples/asymre_gsm8k/train_gsm8k.yaml
save_interval: 100
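
For intuition, here is a minimal sketch of how an AsymRE-style advantage follows from the settings above (`repeat_times: 8`, `baseline_shift: -0.1`): the rollouts for one task form a group, and each reward is compared against the group mean shifted by `baseline_shift`. This is an illustrative reading of the paper, not the repository's `ASYMREAdvantageFn` implementation.

```python
import numpy as np

def asymre_advantages(rewards: np.ndarray, baseline_shift: float = -0.1) -> np.ndarray:
    """Illustrative AsymRE-style advantages for one group of rollouts.

    `rewards` holds the scalar rewards of the rollouts sampled for a
    single task (repeat_times of them). The baseline is the group mean
    plus a typically negative shift, which asymmetrically re-weights
    positive versus negative feedback.
    """
    baseline = rewards.mean() + baseline_shift
    return rewards - baseline

# One task, repeat_times = 8 rollouts with 0/1 correctness rewards:
group = np.array([1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0])
print(asymre_advantages(group))  # baseline = 0.375 + (-0.1) = 0.275
```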
48 changes: 48 additions & 0 deletions examples/asymre_gsm8k/train_gsm8k.yaml
@@ -0,0 +1,48 @@
actor_rollout_ref:
hybrid_engine: True
model:
external_lib: null
override_config: { }
enable_gradient_checkpointing: True
use_remove_padding: True
actor:
strategy: fsdp # This is for backward-compatibility
ppo_micro_batch_size_per_gpu: 8
use_dynamic_bsz: True
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
optim:
lr: 1e-6
lr_warmup_steps_ratio: 0. # the total steps will be injected at runtime
# min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
param_offload: False
optimizer_offload: False
fsdp_size: -1
ref:
fsdp_config:
param_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
log_prob_micro_batch_size_per_gpu: 16
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size

trainer:
balance_batch: True
# auto: resume from the last checkpoint if one exists; otherwise start from scratch
resume_mode: auto # or disable, or resume_path to resume from a specific checkpoint
default_hdfs_dir: null
remove_previous_ckpt_in_save: False
del_local_ckpt_after_load: False
val_before_train: False
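
The `${actor_rollout_ref...}` entries in the `ref` section are OmegaConf-style interpolations that verl resolves when loading the file, so the reference model reuses the actor's batching settings. A minimal sketch of resolving them directly (assuming `omegaconf` is installed and the file path below is correct):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/asymre_gsm8k/train_gsm8k.yaml")
OmegaConf.resolve(cfg)  # materialize ${...} interpolations in place
print(cfg.actor_rollout_ref.ref.log_prob_use_dynamic_bsz)  # True, mirrors the actor
```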
7 changes: 7 additions & 0 deletions examples/asymre_math/README.md
@@ -0,0 +1,7 @@
# Example: AsymRE on MATH dataset

This example demonstrates how to use [AsymRE](https://arxiv.org/abs/2506.20520) on the [MATH dataset](https://huggingface.co/datasets/nlile/hendrycks-MATH-benchmark).

For more details, see the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md).

The config files are located in [`math.yaml`](math.yaml) and [`train_math.yaml`](train_math.yaml).
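
As with the GSM8k example, this one can typically be launched with `trinity run --config examples/asymre_math/math.yaml` once the `/PATH/TO/...` placeholders are filled in.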
77 changes: 77 additions & 0 deletions examples/asymre_math/math.yaml
@@ -0,0 +1,77 @@
# Configuration file for the AsymRE Math project.
# Asymmetric REINFORCE for off-Policy Reinforcement Learning: Balancing positive and negative rewards
# https://arxiv.org/abs/2506.20520.
project: "Trinity-RFT-MATH"
name: asymre_math
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
model:
model_path: /PATH/TO/MODEL/ # the path to your model
max_response_tokens: 1024
max_model_len: 1280
algorithm:
algorithm_type: asymre
policy_loss_fn_args:
tau: 0
advantage_fn_args:
baseline_shift: -0.1 # baseline shift for the AsymRE advantage
repeat_times: 8
cluster:
node_num: 1
gpu_per_node: 8

buffer:
total_steps: 2000 # total number of training steps
batch_size: 16 # number of tasks per batch; with repeat_times: 8 this yields 128 trajectories per gradient step
max_retry_times: 3
max_retry_interval: 1
explorer_input:
taskset:
name: math
storage_type: file
path: /PATH/TO/DATASET/
format:
prompt_key: 'problem'
response_key: 'solution'
rollout_args:
temperature: 1.0
top_p: 1.0
logprobs: 0
eval_tasksets:
- name: math
storage_type: file
path: /PATH/TO/DATASET/
split: 'test'
format:
prompt_key: 'problem'
response_key: 'solution'
rollout_args:
temperature: 0.1
top_p: 0.95
default_workflow_type: math_boxed_workflow
default_reward_fn_type: math_boxed_reward
trainer_input:
experience_buffer:
name: math_buffer
storage_type: queue
# path: 'sqlite:///math.db'
explorer:
eval_interval: 250
runner_num: 32
rollout_model:
engine_type: vllm_async
engine_num: 4
tensor_parallel_size: 1 # Each engine uses 1 GPU
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
max_prompt_tokens: 1024
max_response_tokens: 2048
seed: 42
synchronizer:
sync_method: 'nccl'
sync_interval: 250
sync_timeout: 3600 # synchronization timeout in seconds (1 hour)
trainer:
trainer_type: 'verl'
trainer_config_path: 'examples/asymre_math/train_math.yaml'
save_interval: 500
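
As a quick sanity check on the schedule implied by these settings, the per-step trajectory count and the sync/eval/checkpoint cadence work out as below (plain arithmetic, not repository code):

```python
batch_size, repeat_times, total_steps = 16, 8, 2000

trajectories_per_step = batch_size * repeat_times  # 128 trajectories per gradient step
weight_syncs = total_steps // 250                  # sync_interval: 250 -> 8 syncs
evaluations = total_steps // 250                   # eval_interval: 250 -> 8 eval runs
checkpoints = total_steps // 500                   # save_interval: 500 -> 4 checkpoints
print(trajectories_per_step, weight_syncs, evaluations, checkpoints)
```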
48 changes: 48 additions & 0 deletions examples/asymre_math/train_math.yaml
@@ -0,0 +1,48 @@
actor_rollout_ref:
hybrid_engine: True
model:
external_lib: null
override_config: { }
enable_gradient_checkpointing: True
use_remove_padding: True
actor:
strategy: fsdp # This is for backward-compatibility
ppo_micro_batch_size_per_gpu: 8
use_dynamic_bsz: True
ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
grad_clip: 1.0
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
optim:
lr: 6e-8
lr_warmup_steps_ratio: 0. # the total steps will be injected at runtime
# min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
total_training_steps: -1
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
param_offload: False
optimizer_offload: False
fsdp_size: -1
ref:
fsdp_config:
param_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
log_prob_micro_batch_size_per_gpu: 16
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size

trainer:
balance_batch: True
# auto: resume from the last checkpoint if one exists; otherwise start from scratch
resume_mode: auto # or disable, or resume_path to resume from a specific checkpoint
default_hdfs_dir: null
remove_previous_ckpt_in_save: False
del_local_ckpt_after_load: False
val_before_train: False
94 changes: 93 additions & 1 deletion tests/utils/eval_utils_test.py
@@ -3,10 +3,102 @@

import unittest

from trinity.utils.eval_utils import is_equiv
from trinity.utils.eval_utils import compute_score, is_equiv
from trinity.utils.math_eval_utils import extract_answer, verify_math_answer


class TestComputeScore(unittest.TestCase):
"""
A suite of unit tests for the compute_score function.
"""

def test_both_boxed_and_equivalent(self):
"""
Tests the case where both solution and ground truth have equivalent boxed answers.
Expected score: 1.0
"""
solution = "The final answer is \\boxed{42}"
truth = "The correct result is \\boxed{42}"
self.assertEqual(compute_score(solution, truth), 1.0)

def test_identical_boxed_solution_and_ground_truth(self):
"""
Tests the case where the solution and the ground truth are identical strings, each containing the same boxed answer.
Expected score: 1.0
"""
solution = "The answer is \\boxed{42}"
truth = "The answer is \\boxed{42}"
self.assertEqual(compute_score(solution, truth), 1.0)

def test_solution_boxed_truth_raw_and_equivalent(self):
"""
Tests the case where the solution is boxed and the ground truth is a raw, equivalent string.
Expected score: 1.0
"""
solution = "Let's see, the result is \\boxed{100}"
truth = "100"
self.assertEqual(compute_score(solution, truth), 1.0)

def test_both_boxed_and_not_equivalent(self):
"""
Tests the case where both have boxed answers, but they are not equivalent.
Expected score: 0.0
"""
solution = "I think the answer is \\boxed{-1}"
truth = "The answer is \\boxed{1}"
self.assertEqual(compute_score(solution, truth), 0.0)

def test_solution_boxed_truth_raw_and_not_equivalent(self):
"""
Tests the case where the solution is boxed and the ground truth is a raw, non-equivalent string.
Expected score: 0.0
"""
solution = "The answer is \\boxed{apple}"
truth = "orange"
self.assertEqual(compute_score(solution, truth), 0.0)

def test_solution_not_boxed(self):
"""
Tests the case where the solution string does not contain a boxed answer.
Expected score: 0.0, regardless of the ground truth.
"""
solution = "The answer is 42, but I'm not boxing it."
truth_boxed = "The answer is \\boxed{42}"
truth_raw = "42"
self.assertEqual(compute_score(solution, truth_boxed), 0.0)
self.assertEqual(compute_score(solution, truth_raw), 0.0)

def test_empty_solution_string(self):
"""
Tests behavior with an empty solution string.
Expected score: 0.0
"""
solution = ""
truth = "\\boxed{10}"
self.assertEqual(compute_score(solution, truth), 0.0)

def test_empty_ground_truth(self):
"""
Tests behavior with an empty ground truth string.
Expected score: 0.0 unless the boxed answer is also empty.
"""
solution_correct = "The answer is \\boxed{}"
solution_incorrect = "The answer is \\boxed{1}"
truth = ""
self.assertEqual(compute_score(solution_correct, truth), 1.0)
self.assertEqual(compute_score(solution_incorrect, truth), 0.0)

def test_multiple_boxed_answers_in_solution(self):
"""
Tests that only the *last* boxed answer in the solution is used for scoring.
"""
solution = "First I thought it was \\boxed{A}, but then I realized it is \\boxed{B}"
truth_correct = "\\boxed{B}"
truth_incorrect = "\\boxed{A}"
self.assertEqual(compute_score(solution, truth_correct), 1.0)
self.assertEqual(compute_score(solution, truth_incorrect), 0.0)


class TestMathEvalUtils(unittest.TestCase):
def test_extract_answer(self):
test_cases = [
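Taken together, the tests above pin down the contract of `compute_score`: only the last `\boxed{...}` in the solution counts, a solution with no boxed answer scores 0.0, and the extracted answer is checked against a boxed or raw ground truth via `is_equiv`. A minimal sketch of that contract (illustrative only; the real implementation lives in `trinity.utils.eval_utils`, and exact-string comparison stands in for `is_equiv`):

```python
def last_boxed_content(text: str) -> str | None:
    """Return the contents of the last \\boxed{...} in `text`, or None.

    Walks braces explicitly so nested expressions like \\boxed{\\frac{1}{2}}
    are captured whole.
    """
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    i = begin = start + len("\\boxed{")
    depth = 1
    while i < len(text) and depth:
        depth += {"{": 1, "}": -1}.get(text[i], 0)
        i += 1
    return text[begin : i - 1] if depth == 0 else None

def compute_score_sketch(solution: str, truth: str) -> float:
    pred = last_boxed_content(solution)
    if pred is None:
        return 0.0  # unboxed solutions never score
    target = last_boxed_content(truth)
    if target is None:
        target = truth.strip()  # raw ground truth is used as-is
    return 1.0 if pred.strip() == target else 0.0

assert compute_score_sketch("it is \\boxed{B}", "\\boxed{B}") == 1.0
assert compute_score_sketch("the answer is 42", "42") == 0.0
```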
2 changes: 2 additions & 0 deletions trinity/algorithm/advantage_fn/__init__.py
@@ -3,6 +3,7 @@
AdvantageFn,
GroupAdvantage,
)
from trinity.algorithm.advantage_fn.asymre_advantage import ASYMREAdvantageFn
from trinity.algorithm.advantage_fn.grpo_advantage import (
GRPOAdvantageFn,
GRPOGroupedAdvantage,
@@ -34,4 +35,5 @@
"RLOOAdvantageFn",
"OPMDAdvantageFn",
"OPMDGroupAdvantage",
"ASYMREAdvantageFn",
]
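
With this export in place, the new advantage function is importable alongside the existing ones. The construction below is hypothetical: it assumes, based only on `advantage_fn_args` in the example configs rather than the class signature in `asymre_advantage.py`, that the YAML arguments map to constructor keywords.

```python
from trinity.algorithm.advantage_fn import ASYMREAdvantageFn

# Hypothetical: mirrors advantage_fn_args in examples/asymre_gsm8k/gsm8k.yaml.
adv_fn = ASYMREAdvantageFn(baseline_shift=-0.1)
```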