diff --git a/examples/asymre_gsm8k/README.md b/examples/asymre_gsm8k/README.md
new file mode 100644
index 0000000000..61a5cb9bf9
--- /dev/null
+++ b/examples/asymre_gsm8k/README.md
@@ -0,0 +1,7 @@
+# Example: AsymRE on GSM8k dataset
+
+This example shows the usage of [AsymRE](https://arxiv.org/abs/2506.20520) on the [GSM8k dataset](https://huggingface.co/datasets/openai/gsm8k).
+
+For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md).
+
+The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml).
diff --git a/examples/asymre_gsm8k/gsm8k.yaml b/examples/asymre_gsm8k/gsm8k.yaml
new file mode 100644
index 0000000000..d825d8f89f
--- /dev/null
+++ b/examples/asymre_gsm8k/gsm8k.yaml
@@ -0,0 +1,66 @@
+project: sync_offset_0_sync20
+name: asymre-gsm8k_shift-0.1
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+model:
+  model_path: /PATH/TO/MODEL/
+  max_response_tokens: 1024
+  max_model_len: 1280
+algorithm:
+  algorithm_type: asymre
+  policy_loss_fn_args:
+    tau: 0
+  advantage_fn_args:
+    baseline_shift: -0.1
+  repeat_times: 8
+cluster:
+  node_num: 1
+  gpu_per_node: 8
+buffer:
+  total_steps: 80
+  batch_size: 96
+  max_retry_times: 3
+  max_retry_interval: 1
+  explorer_input:
+    taskset:
+      name: gsm8k
+      storage_type: file
+      path: /PATH/TO/DATASET/
+      split: train
+      format:
+        prompt_key: question
+        response_key: answer
+      rollout_args:
+        temperature: 1.0
+    eval_tasksets:
+      - name: gsm8k-eval
+        storage_type: file
+        path: /PATH/TO/DATASET/
+        split: test
+        format:
+          prompt_key: question
+          response_key: answer
+    default_workflow_type: math_workflow
+  trainer_input:
+    experience_buffer:
+      name: gsm8k_buffer
+      storage_type: queue
+explorer:
+  eval_interval: 20
+  runner_num: 64
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 4
+    tensor_parallel_size: 1
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    seed: 42
+synchronizer:
+  sync_method: nccl
+  sync_interval: 20
+  sync_timeout: 1200
+  sync_offset: 0
+trainer:
+  trainer_type: verl
+  trainer_config_path: examples/asymre_gsm8k/train_gsm8k.yaml
+  save_interval: 100
diff --git a/examples/asymre_gsm8k/train_gsm8k.yaml b/examples/asymre_gsm8k/train_gsm8k.yaml
new file mode 100644
index 0000000000..00b6ffb9f6
--- /dev/null
+++ b/examples/asymre_gsm8k/train_gsm8k.yaml
@@ -0,0 +1,48 @@
+actor_rollout_ref:
+  hybrid_engine: True
+  model:
+    external_lib: null
+    override_config: { }
+    enable_gradient_checkpointing: True
+    use_remove_padding: True # False
+  actor:
+    strategy: fsdp # This is for backward-compatibility
+    ppo_micro_batch_size_per_gpu: 8
+    use_dynamic_bsz: True # False
+    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
+    grad_clip: 1.0
+    ppo_epochs: 1
+    shuffle: False
+    ulysses_sequence_parallel_size: 1 # sp size
+    optim:
+      lr: 1e-6
+      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
+      # min_lr_ratio: null # only useful for warmup with cosine
+      warmup_style: constant # select from constant/cosine
+      total_training_steps: -1
+    fsdp_config:
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+      param_offload: False
+      optimizer_offload: False
+      fsdp_size: -1
+  ref:
+    fsdp_config:
+      param_offload: False
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+    log_prob_micro_batch_size_per_gpu: 16
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
+
+trainer:
+  balance_batch: True
+  # auto: find the last ckpt to resume; if none is found, start from scratch
+  resume_mode: auto # or disable or resume_path if resume_from_path is set
+  default_hdfs_dir: null
+  remove_previous_ckpt_in_save: False
+  del_local_ckpt_after_load: False
+  val_before_train: False
diff --git a/examples/asymre_math/README.md b/examples/asymre_math/README.md
new file mode 100644
index 0000000000..746b4e8695
--- /dev/null
+++ b/examples/asymre_math/README.md
@@ -0,0 +1,7 @@
+# Example: AsymRE on MATH dataset
+
+This example shows the usage of [AsymRE](https://arxiv.org/abs/2506.20520) on the [MATH dataset](https://huggingface.co/datasets/nlile/hendrycks-MATH-benchmark).
+
+For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md).
+
+The config files are located in [`math.yaml`](math.yaml) and [`train_math.yaml`](train_math.yaml).
diff --git a/examples/asymre_math/math.yaml b/examples/asymre_math/math.yaml
new file mode 100644
index 0000000000..b8472f9e2d
--- /dev/null
+++ b/examples/asymre_math/math.yaml
@@ -0,0 +1,77 @@
+# Configuration file for the AsymRE MATH example.
+# REINFORCE for off-Policy Reinforcement Learning: Balancing positive and negative rewards
+# https://arxiv.org/abs/2506.20520
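+#
+# A rough summary of the advantage used below (see the paper and
+# trinity/algorithm/advantage_fn/asymre_advantage.py in this patch for details):
+# each rollout's advantage is its reward minus a shifted group baseline,
+#   advantage_i = reward_i - (mean_j(reward_j) + baseline_shift),
+# so the negative `baseline_shift` set under `algorithm.advantage_fn_args`
+# lowers the baseline and lets more rollouts receive a positive advantage.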
+project: "Trinity-RFT-MATH"
+name: asymre_math
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+model:
+  model_path: /PATH/TO/MODEL/ # the path to your model
+  max_response_tokens: 1024
+  max_model_len: 1280
+algorithm:
+  algorithm_type: asymre
+  policy_loss_fn_args:
+    tau: 0
+  advantage_fn_args:
+    baseline_shift: -0.1 # baseline shift for AsymRE
+  repeat_times: 8
+cluster:
+  node_num: 1
+  gpu_per_node: 8
+
+buffer:
+  total_steps: 2000 # 2000 training steps in total
+  batch_size: 16 # number of tasks per batch; with repeat_times: 8 this gives 128 trajectories per gradient step
+  max_retry_times: 3
+  max_retry_interval: 1
+  explorer_input:
+    taskset:
+      name: math
+      storage_type: file
+      path: /PATH/TO/DATASET/
+      format:
+        prompt_key: 'problem'
+        response_key: 'solution'
+      rollout_args:
+        temperature: 1.0
+        top_p: 1.0
+        logprobs: 0
+    eval_tasksets:
+      - name: math
+        storage_type: file
+        path: /PATH/TO/DATASET/
+        split: 'test'
+        format:
+          prompt_key: 'problem'
+          response_key: 'solution'
+        rollout_args:
+          temperature: 0.1
+          top_p: 0.95
+    default_workflow_type: math_boxed_workflow
+    default_reward_fn_type: math_boxed_reward
+  trainer_input:
+    experience_buffer:
+      name: math_buffer
+      storage_type: queue
+      # path: 'sqlite:///math.db'
+explorer:
+  eval_interval: 250
+  runner_num: 32
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 4
+    tensor_parallel_size: 1 # each engine uses 1 GPU
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    max_prompt_tokens: 1024
+    max_response_tokens: 2048
+    seed: 42
+synchronizer:
+  sync_method: 'nccl'
+  sync_interval: 250
+  sync_timeout: 3600 # 3600 seconds (1 hour)
+trainer:
+  trainer_type: 'verl'
+  trainer_config_path: 'examples/asymre_math/train_math.yaml'
+  save_interval: 500
diff --git a/examples/asymre_math/train_math.yaml b/examples/asymre_math/train_math.yaml
new file mode 100644
index 0000000000..6b11b5bfa5
--- /dev/null
+++ b/examples/asymre_math/train_math.yaml
@@ -0,0 +1,48 @@
+actor_rollout_ref:
+  hybrid_engine: True
+  model:
+    external_lib: null
+    override_config: { }
+    enable_gradient_checkpointing: True
+    use_remove_padding: True # False
+  actor:
+    strategy: fsdp # This is for backward-compatibility
+    ppo_micro_batch_size_per_gpu: 8
+    use_dynamic_bsz: True # False
+    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
+    grad_clip: 1.0
+    ppo_epochs: 1
+    shuffle: False
+    ulysses_sequence_parallel_size: 1 # sp size
+    optim:
+      lr: 6e-8
+      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
+      # min_lr_ratio: null # only useful for warmup with cosine
+      warmup_style: constant # select from constant/cosine
+      total_training_steps: -1
+    fsdp_config:
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+      param_offload: False
+      optimizer_offload: False
+      fsdp_size: -1
+  ref:
+    fsdp_config:
+      param_offload: False
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+    log_prob_micro_batch_size_per_gpu: 16
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
+
+trainer:
+  balance_batch: True
+  # auto: find the last ckpt to resume; if none is found, start from scratch
+  resume_mode: auto # or disable or resume_path if resume_from_path is set
+  default_hdfs_dir: null
+  remove_previous_ckpt_in_save: False
+  del_local_ckpt_after_load: False
+  val_before_train: False
diff --git a/tests/utils/eval_utils_test.py b/tests/utils/eval_utils_test.py
index 4cb137bf73..533c04b836 100644
--- a/tests/utils/eval_utils_test.py
+++ b/tests/utils/eval_utils_test.py
@@ -3,10 +3,102 @@

 import unittest

-from trinity.utils.eval_utils import is_equiv
+from trinity.utils.eval_utils import compute_score, is_equiv
 from trinity.utils.math_eval_utils import extract_answer, verify_math_answer


+class TestComputeScore(unittest.TestCase):
+    """
+    A suite of unit tests for the compute_score function.
+    """
+
+    def test_both_boxed_and_equivalent(self):
+        """
+        Tests the case where both the solution and the ground truth have equivalent boxed answers.
+        Expected score: 1.0
+        """
+        solution = "The final answer is \\boxed{42}"
+        truth = "The correct result is \\boxed{42}"
+        self.assertEqual(compute_score(solution, truth), 1.0)
+
+    def test_identical_boxed_strings_equivalent(self):
+        """
+        Tests the case where the solution and the ground truth are identical strings that both contain a boxed answer.
+        Expected score: 1.0
+        """
+        solution = "The answer is \\boxed{42}"
+        truth = "The answer is \\boxed{42}"
+        self.assertEqual(compute_score(solution, truth), 1.0)
+
+    def test_solution_boxed_truth_raw_and_equivalent(self):
+        """
+        Tests the case where the solution is boxed and the ground truth is a raw, equivalent string.
+        Expected score: 1.0
+        """
+        solution = "Let's see, the result is \\boxed{100}"
+        truth = "100"
+        self.assertEqual(compute_score(solution, truth), 1.0)
+
+    def test_both_boxed_and_not_equivalent(self):
+        """
+        Tests the case where both have boxed answers, but they are not equivalent.
+        Expected score: 0.0
+        """
+        solution = "I think the answer is \\boxed{-1}"
+        truth = "The answer is \\boxed{1}"
+        self.assertEqual(compute_score(solution, truth), 0.0)
+
+    def test_solution_boxed_truth_raw_and_not_equivalent(self):
+        """
+        Tests the case where the solution is boxed and the ground truth is a raw, non-equivalent string.
+        Expected score: 0.0
+        """
+        solution = "The answer is \\boxed{apple}"
+        truth = "orange"
+        self.assertEqual(compute_score(solution, truth), 0.0)
+
+    def test_solution_not_boxed(self):
+        """
+        Tests the case where the solution string does not contain a boxed answer.
+        Expected score: 0.0, regardless of the ground truth.
+        """
+        solution = "The answer is 42, but I'm not boxing it."
+        truth_boxed = "The answer is \\boxed{42}"
+        truth_raw = "42"
+        self.assertEqual(compute_score(solution, truth_boxed), 0.0)
+        self.assertEqual(compute_score(solution, truth_raw), 0.0)
+
+    def test_empty_solution_string(self):
+        """
+        Tests behavior with an empty solution string.
+        Expected score: 0.0
+        """
+        solution = ""
+        truth = "\\boxed{10}"
+        self.assertEqual(compute_score(solution, truth), 0.0)
+
+    def test_empty_ground_truth(self):
+        """
+        Tests behavior with an empty ground truth string.
+        Expected score: 0.0 unless the boxed answer is also empty.
+        """
+        solution_correct = "The answer is \\boxed{}"
+        solution_incorrect = "The answer is \\boxed{1}"
+        truth = ""
+        self.assertEqual(compute_score(solution_correct, truth), 1.0)
+        self.assertEqual(compute_score(solution_incorrect, truth), 0.0)
+
+    def test_multiple_boxed_answers_in_solution(self):
+        """
+        Tests that only the *last* boxed answer in the solution is used for scoring.
+        """
+        solution = "First I thought it was \\boxed{A}, but then I realized it is \\boxed{B}"
+        truth_correct = "\\boxed{B}"
+        truth_incorrect = "\\boxed{A}"
+        self.assertEqual(compute_score(solution, truth_correct), 1.0)
+        self.assertEqual(compute_score(solution, truth_incorrect), 0.0)
+
+
 class TestMathEvalUtils(unittest.TestCase):
     def test_extract_answer(self):
         test_cases = [
diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py
index 7e99c2cb5c..f7e1932d98 100644
--- a/trinity/algorithm/advantage_fn/__init__.py
+++ b/trinity/algorithm/advantage_fn/__init__.py
@@ -3,6 +3,7 @@
     AdvantageFn,
     GroupAdvantage,
 )
+from trinity.algorithm.advantage_fn.asymre_advantage import ASYMREAdvantageFn
 from trinity.algorithm.advantage_fn.grpo_advantage import (
     GRPOAdvantageFn,
     GRPOGroupedAdvantage,
@@ -34,4 +35,5 @@
     "RLOOAdvantageFn",
     "OPMDAdvantageFn",
     "OPMDGroupAdvantage",
+    "ASYMREAdvantageFn",
 ]
diff --git a/trinity/algorithm/advantage_fn/asymre_advantage.py b/trinity/algorithm/advantage_fn/asymre_advantage.py
new file mode 100644
index 0000000000..bd97f328ca
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/asymre_advantage.py
@@ -0,0 +1,121 @@
+"""AsymRE advantage computation"""
+
+from collections import defaultdict
+from typing import Dict, List, Tuple
+
+import torch
+from verl import DataProto
+
+from trinity.algorithm.advantage_fn.advantage_fn import (
+    ADVANTAGE_FN,
+    AdvantageFn,
+    GroupAdvantage,
+)
+from trinity.common.experience import Experience, group_by
+from trinity.utils.annotations import Deprecated
+
+
+@Deprecated
+@ADVANTAGE_FN.register_module("asymre_verl")
+class ASYMREAdvantageFn(AdvantageFn):
+    """AsymRE advantage computation"""
+
+    def __init__(
+        self,
+        baseline_shift: float = -0.1,
+    ) -> None:
+        self.baseline_shift = baseline_shift
+
+    def __call__(
+        self,
+        exps: DataProto,
+        **kwargs,
+    ) -> Tuple[DataProto, Dict]:
+        """Modified from compute_grpo_outcome_advantage
+
+        Compute advantage for AsymRE, operating only on the outcome reward
+        (a single scalar reward for each response).
+
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+        scores: `(torch.Tensor)`
+            shape: (bs, response_length)
+        """
+        token_level_rewards = exps.batch["token_level_rewards"]
+        eos_mask = exps.batch["response_mask"]
+
+        index = exps.non_tensor_batch["uid"]
+        baseline_shift = self.baseline_shift
+
+        response_length = token_level_rewards.shape[-1]
+        scores = token_level_rewards.sum(dim=-1)
+
+        id2score = defaultdict(list)
+        id2baseline = {}
+
+        with torch.no_grad():
+            bsz = scores.shape[0]
+            for i in range(bsz):
+                id2score[index[i]].append(scores[i])
+            for idx in id2score:
+                if len(id2score[idx]) == 1:
+                    id2baseline[idx] = torch.tensor(0.0)
+                    # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?)
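+                    # (A zero baseline means a lone rollout keeps its raw reward as its advantage.)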
+                elif len(id2score[idx]) > 1:
+                    id2baseline[idx] = torch.mean(torch.tensor(id2score[idx])) + baseline_shift
+                else:
+                    raise ValueError(f"no score in prompt index: {idx}")
+            for i in range(bsz):
+                scores[i] = scores[i] - id2baseline[index[i]]
+            scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
+
+        exps.batch["advantages"] = scores
+        exps.batch["returns"] = scores
+
+        metrics = {
+            # TODO: add meaningful metrics
+        }
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {
+            "baseline_shift": -0.1,
+        }
+
+
+@ADVANTAGE_FN.register_module("asymre")
+class ASYMREGroupAdvantage(GroupAdvantage):
+    """AsymRE group advantage computation"""
+
+    def __init__(self, baseline_shift: float = -0.1, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.baseline_shift = baseline_shift
+
+    def group_experiences(self, exps):
+        return group_by(exps, id_type="task")
+
+    def calculate_group_advantage(
+        self, group_id: str, exps: List[Experience]
+    ) -> Tuple[List[Experience], Dict]:
+        with torch.no_grad():
+            if len(exps) == 1:
+                group_baseline = torch.tensor(0.0)
+            else:
+                group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32)
+                group_baseline = torch.mean(group_rewards) + self.baseline_shift
+            for exp in exps:
+                score = exp.reward - group_baseline
+                exp.advantages = score * exp.action_mask
+                exp.returns = exp.advantages.clone()
+            metrics = {
+                "group_baseline": group_baseline.item(),
+            }
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> dict:
+        return {"baseline_shift": -0.1}
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index 4457603b5b..fcedf20267 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -132,6 +132,29 @@ def default_config(cls) -> Dict:
         }


+@ALGORITHM_TYPE.register_module("asymre")
+class AsymREAlgorithm(AlgorithmType):
+    """AsymRE algorithm."""
+
+    use_critic: bool = False
+    use_reference: bool = False
+    compute_advantage_in_trainer: bool = False
+    can_balance_batch: bool = True
+    schema: type = ExperienceModel
+
+    @classmethod
+    def default_config(cls) -> Dict:
+        return {
+            "repeat_times": 2,
+            "sample_strategy": "warmup",
+            "policy_loss_fn": "opmd",
+            "advantage_fn": "asymre",
+            "kl_penalty_fn": "none",
+            "kl_loss_fn": "none",
+            "entropy_loss_fn": "none",
+        }
+
+
 @ALGORITHM_TYPE.register_module("dpo")
 class DPOAlgorithm(AlgorithmType):
     """DPO algorithm."""
diff --git a/trinity/utils/eval_utils.py b/trinity/utils/eval_utils.py
index 3308825e81..b6fa7070f1 100644
--- a/trinity/utils/eval_utils.py
+++ b/trinity/utils/eval_utils.py
@@ -108,10 +108,22 @@ def compute_score(solution_str, ground_truth) -> float:
     retval = 0.0
     try:
         string_in_last_boxed = last_boxed_only_string(solution_str)
+        original_ground_truth = ground_truth
+        boxed_ground_truth = last_boxed_only_string(ground_truth)
+        # Determine whether ground_truth is a raw solution string (contains a boxed answer) or an already-extracted answer
+        ground_truth_was_raw = boxed_ground_truth is not None
         if string_in_last_boxed is not None:
             answer = remove_boxed(string_in_last_boxed)
+            if ground_truth_was_raw:
+                # The ground truth contains a boxed answer - extract it
+                ground_truth = remove_boxed(boxed_ground_truth)
+            else:
+                # The ground truth has no boxed answer - use it as is
+                ground_truth = original_ground_truth
             if is_equiv(answer, ground_truth):
                 retval = 1.0
+            # logger.warning(answer, " ", ground_truth, " ", retval)
+
     except Exception as e:
         print(e)
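
For reviewers who want to sanity-check the advantage math without running the trainer, here is a minimal standalone sketch of what `ASYMREGroupAdvantage.calculate_group_advantage` in this patch computes. It is an illustration only: `rewards` and `action_mask` are hypothetical stand-in tensors, not Trinity-RFT's `Experience` API.

```python
import torch


def asymre_group_advantage(
    rewards: torch.Tensor,  # shape (group_size,): one scalar reward per rollout of the same task
    action_mask: torch.Tensor,  # shape (group_size, response_length): 1 for response tokens, 0 for padding
    baseline_shift: float = -0.1,
) -> torch.Tensor:
    """Per-token advantages under the AsymRE shifted group baseline."""
    if rewards.numel() == 1:
        baseline = torch.tensor(0.0)  # a lone rollout has no group to form a baseline from
    else:
        baseline = rewards.mean() + baseline_shift
    # Broadcast each rollout's scalar advantage over its response tokens.
    return (rewards - baseline).unsqueeze(-1) * action_mask


# Example: 4 rollouts of one task with rewards in {0, 1} and baseline_shift = -0.1,
# so the baseline is 0.5 - 0.1 = 0.4 and the per-token advantages are [0.6, -0.4, -0.4, 0.6].
adv = asymre_group_advantage(torch.tensor([1.0, 0.0, 0.0, 1.0]), torch.ones(4, 8))
```

Unlike GRPO-style normalization, the group reward standard deviation is not used here; the only knob is the additive `baseline_shift`.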