From f99386c7a836ec2393832a41eea231757d0b728b Mon Sep 17 00:00:00 2001 From: cr Date: Wed, 13 Aug 2025 12:19:41 -0700 Subject: [PATCH 01/10] [Algo] Implemented AsymRE. --- trinity/algorithm/advantage_fn/__init__.py | 2 + .../advantage_fn/asymre_advantage.py | 80 +++++++++++++++++++ trinity/algorithm/algorithm.py | 22 +++++ .../algorithm_config_manager.py | 1 + 4 files changed, 105 insertions(+) create mode 100644 trinity/algorithm/advantage_fn/asymre_advantage.py diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py index 7e99c2cb5c..d2e34b2688 100644 --- a/trinity/algorithm/advantage_fn/__init__.py +++ b/trinity/algorithm/advantage_fn/__init__.py @@ -20,6 +20,7 @@ ) from trinity.algorithm.advantage_fn.remax_advantage import REMAXAdvantageFn from trinity.algorithm.advantage_fn.rloo_advantage import RLOOAdvantageFn +from trinity.algorithm.advantage_fn.asymre_advantage import ASYMREAdvantageFn __all__ = [ "ADVANTAGE_FN", @@ -34,4 +35,5 @@ "RLOOAdvantageFn", "OPMDAdvantageFn", "OPMDGroupAdvantage", + "ASYMREAdvantageFn" ] diff --git a/trinity/algorithm/advantage_fn/asymre_advantage.py b/trinity/algorithm/advantage_fn/asymre_advantage.py new file mode 100644 index 0000000000..f24d33c711 --- /dev/null +++ b/trinity/algorithm/advantage_fn/asymre_advantage.py @@ -0,0 +1,80 @@ +"""AsymRE advantage computation""" + +from collections import defaultdict +from typing import Dict, Tuple + +import torch +from verl import DataProto + +from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn + + +@ADVANTAGE_FN.register_module("asymre") +class ASYMREAdvantageFn(AdvantageFn): + """AsymRE advantage computation""" + + def __init__( + self, + asymre_baseline_shift: float = -0.1, + ) -> None: + self.asymre_baseline_shift = asymre_baseline_shift + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + """Modified from compute_grpo_outcome_advantage + + Compute advantage for AsymRE, operating only 
on Outcome reward + (with only one scalar reward for each response). + + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + eos_mask: `(torch.Tensor)` + shape: (bs, response_length) + scores: `(torch.Tensor)` + shape: (bs, response_length) + """ + token_level_rewards = exps.batch["token_level_rewards"] + eos_mask = exps.batch["response_mask"] + + index = exps.non_tensor_batch["uid"] + asymre_baseline_shift = self.asymre_baseline_shift + + response_length = token_level_rewards.shape[-1] + scores = token_level_rewards.sum(dim=-1) + + id2score = defaultdict(list) + id2baseline = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2baseline[idx] = torch.tensor(0.0) + # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?) + elif len(id2score[idx]) > 1: + id2baseline[idx] = torch.mean(torch.tensor(id2score[idx])) + asymre_baseline_shift + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = scores[i] - id2baseline[index[i]] + scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask + + exps.batch["advantages"] = scores + exps.batch["returns"] = scores + + metrics = { + # TODO: add meaningful metrics + } + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "asymre_baseline_shift": -0.1, + } diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 4457603b5b..00c99b445a 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -175,6 +175,28 @@ def check_config(cls, config: Config) -> None: config.algorithm.kl_loss_fn = "k2" logger.warning("DPO must use KL loss. 
Set `algorithm.kl_loss_fn` to `k2`") +@ALGORITHM_TYPE.register_module("asymre") +class AsymREAlgorithm(AlgorithmType): + """AsymRE algorithm.""" + + use_critic: bool = False + use_reference: bool = False + use_advantage: bool = True + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 2, + "sample_strategy": "warmup", + "policy_loss_fn": "opmd", + "advantage_fn": "asymre", + "kl_penalty_fn": "none", + "kl_loss_fn": "none", + "entropy_loss_fn": "none", + } + @ALGORITHM_TYPE.register_module("mix") class MIXAlgorithm(AlgorithmType): diff --git a/trinity/manager/config_registry/algorithm_config_manager.py b/trinity/manager/config_registry/algorithm_config_manager.py index c9694dec25..bdceb2fa7b 100644 --- a/trinity/manager/config_registry/algorithm_config_manager.py +++ b/trinity/manager/config_registry/algorithm_config_manager.py @@ -4,6 +4,7 @@ ADVANTAGE_FN, GRPOAdvantageFn, OPMDAdvantageFn, + ASYMREAdvantageFn, PPOAdvantageFn, ) from trinity.algorithm.algorithm import ALGORITHM_TYPE, PPOAlgorithm From 8276346a416079a19980d83867a5a81e3a170ed1 Mon Sep 17 00:00:00 2001 From: cr Date: Wed, 13 Aug 2025 13:43:42 -0700 Subject: [PATCH 02/10] [Example] Added AsymRE on MATH. --- examples/asymre_math/README.md | 7 +++ examples/asymre_math/math.yaml | 72 ++++++++++++++++++++++++++++ examples/asymre_math/train_math.yaml | 48 +++++++++++++++++++ 3 files changed, 127 insertions(+) create mode 100644 examples/asymre_math/README.md create mode 100644 examples/asymre_math/math.yaml create mode 100644 examples/asymre_math/train_math.yaml diff --git a/examples/asymre_math/README.md b/examples/asymre_math/README.md new file mode 100644 index 0000000000..cb56bcb32a --- /dev/null +++ b/examples/asymre_math/README.md @@ -0,0 +1,7 @@ +# Example: AsymRE on MATH dataset + +This example shows the usage of AsymRE on the [MATH dataset](https://huggingface.co/datasets/nlile/hendrycks-MATH-benchmark). 
+ +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md). + +The config files are located in [`math.yaml`](math.yaml) and [`train_math.yaml`](train_math.yaml). diff --git a/examples/asymre_math/math.yaml b/examples/asymre_math/math.yaml new file mode 100644 index 0000000000..48d238b258 --- /dev/null +++ b/examples/asymre_math/math.yaml @@ -0,0 +1,72 @@ +project: "Trinity-RFT-MATH" +name: asymre_math +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +model: + model_path: /PATH/TO/MODEL/ # the path to your model +algorithm: + algorithm_type: asymre + policy_loss_fn_args: + tau: 0 + advantage_fn_args: + asymre_baseline_shift: -0.1 + repeat_times: 8 +cluster: + node_num: 1 + gpu_per_node: 8 + +buffer: + total_steps: 2000 # Exactly 2000 training steps as desired + batch_size: 16 # 128 trajectories per gradient step, batch_size is the number of tasks per batch + max_retry_times: 3 + max_retry_interval: 1 + explorer_input: + taskset: + name: math + storage_type: file + path: /PATH/TO/DATASET/ + format: + prompt_key: 'problem' + response_key: 'solution' + rollout_args: + temperature: 1.0 + top_p: 1.0 + logprobs: 0 + eval_tasksets: + - name: math + storage_type: file + path: /PATH/TO/DATASET/ + split: 'test' + format: + prompt_key: 'problem' + response_key: 'solution' + rollout_args: + temperature: 0.1 + top_p: 0.95 + default_workflow_type: math_boxed_workflow + default_reward_fn_type: math_boxed_reward + trainer_input: + experience_buffer: + name: math_buffer + storage_type: queue + # path: 'sqlite:///math.db' +explorer: + eval_interval: 250 + runner_num: 32 + rollout_model: + engine_type: vllm_async + engine_num: 4 + tensor_parallel_size: 1 # Each engine uses 1 GPU + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + max_prompt_tokens: 1024 + max_response_tokens: 2048 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 500 + sync_timeout: 3600 # Increased from 2000 to 
3600 seconds (1 hour) +trainer: + trainer_type: 'verl' + trainer_config_path: 'examples/asymre_math/train_math.yaml' + save_interval: 500 diff --git a/examples/asymre_math/train_math.yaml b/examples/asymre_math/train_math.yaml new file mode 100644 index 0000000000..735d18c2b2 --- /dev/null +++ b/examples/asymre_math/train_math.yaml @@ -0,0 +1,48 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_micro_batch_size_per_gpu: 8 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 8192 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size_per_gpu: 16 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + +trainer: + balance_batch: True + # auto: find the last ckpt to resume. 
If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + val_before_train: False From 5979497435845e6f97863f55d78feb05bb2b7b29 Mon Sep 17 00:00:00 2001 From: cr Date: Wed, 13 Aug 2025 13:44:55 -0700 Subject: [PATCH 03/10] [Util] Extended the compute_score to support both the raw and processed inputs. --- examples/asymre_math/math.yaml | 8 ++++---- trinity/algorithm/advantage_fn/__init__.py | 7 +++++-- trinity/algorithm/advantage_fn/asymre_advantage.py | 6 ++++-- trinity/algorithm/algorithm.py | 3 ++- .../config_registry/algorithm_config_manager.py | 1 - trinity/utils/eval_utils.py | 12 ++++++++++++ 6 files changed, 27 insertions(+), 10 deletions(-) diff --git a/examples/asymre_math/math.yaml b/examples/asymre_math/math.yaml index 48d238b258..5e6c081644 100644 --- a/examples/asymre_math/math.yaml +++ b/examples/asymre_math/math.yaml @@ -13,7 +13,7 @@ algorithm: cluster: node_num: 1 gpu_per_node: 8 - + buffer: total_steps: 2000 # Exactly 2000 training steps as desired batch_size: 16 # 128 trajectories per gradient step, batch_size is the number of tasks per batch @@ -54,10 +54,10 @@ explorer: runner_num: 32 rollout_model: engine_type: vllm_async - engine_num: 4 + engine_num: 4 tensor_parallel_size: 1 # Each engine uses 1 GPU - enable_prefix_caching: false - enforce_eager: true + enable_prefix_caching: false + enforce_eager: true dtype: bfloat16 max_prompt_tokens: 1024 max_response_tokens: 2048 diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py index d2e34b2688..c2453107c1 100644 --- a/trinity/algorithm/advantage_fn/__init__.py +++ b/trinity/algorithm/advantage_fn/__init__.py @@ -1,8 +1,12 @@ +from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.algorithm.advantage_fn.grpo_advantage import GRPOAdvantageFn +from 
trinity.algorithm.advantage_fn.opmd_advantage import OPMDAdvantageFn from trinity.algorithm.advantage_fn.advantage_fn import ( ADVANTAGE_FN, AdvantageFn, GroupAdvantage, ) +from trinity.algorithm.advantage_fn.asymre_advantage import ASYMREAdvantageFn from trinity.algorithm.advantage_fn.grpo_advantage import ( GRPOAdvantageFn, GRPOGroupedAdvantage, @@ -20,7 +24,6 @@ ) from trinity.algorithm.advantage_fn.remax_advantage import REMAXAdvantageFn from trinity.algorithm.advantage_fn.rloo_advantage import RLOOAdvantageFn -from trinity.algorithm.advantage_fn.asymre_advantage import ASYMREAdvantageFn __all__ = [ "ADVANTAGE_FN", @@ -35,5 +38,5 @@ "RLOOAdvantageFn", "OPMDAdvantageFn", "OPMDGroupAdvantage", - "ASYMREAdvantageFn" + "ASYMREAdvantageFn", ] diff --git a/trinity/algorithm/advantage_fn/asymre_advantage.py b/trinity/algorithm/advantage_fn/asymre_advantage.py index f24d33c711..0ac5b71370 100644 --- a/trinity/algorithm/advantage_fn/asymre_advantage.py +++ b/trinity/algorithm/advantage_fn/asymre_advantage.py @@ -38,7 +38,7 @@ def __call__( """ token_level_rewards = exps.batch["token_level_rewards"] eos_mask = exps.batch["response_mask"] - + index = exps.non_tensor_batch["uid"] asymre_baseline_shift = self.asymre_baseline_shift @@ -57,7 +57,9 @@ def __call__( id2baseline[idx] = torch.tensor(0.0) # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?) 
elif len(id2score[idx]) > 1: - id2baseline[idx] = torch.mean(torch.tensor(id2score[idx])) + asymre_baseline_shift + id2baseline[idx] = ( + torch.mean(torch.tensor(id2score[idx])) + asymre_baseline_shift + ) else: raise ValueError(f"no score in prompt index: {idx}") for i in range(bsz): diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 00c99b445a..582c059c86 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -175,6 +175,7 @@ def check_config(cls, config: Config) -> None: config.algorithm.kl_loss_fn = "k2" logger.warning("DPO must use KL loss. Set `algorithm.kl_loss_fn` to `k2`") + @ALGORITHM_TYPE.register_module("asymre") class AsymREAlgorithm(AlgorithmType): """AsymRE algorithm.""" @@ -196,7 +197,7 @@ def default_config(cls) -> Dict: "kl_loss_fn": "none", "entropy_loss_fn": "none", } - + @ALGORITHM_TYPE.register_module("mix") class MIXAlgorithm(AlgorithmType): diff --git a/trinity/manager/config_registry/algorithm_config_manager.py b/trinity/manager/config_registry/algorithm_config_manager.py index bdceb2fa7b..c9694dec25 100644 --- a/trinity/manager/config_registry/algorithm_config_manager.py +++ b/trinity/manager/config_registry/algorithm_config_manager.py @@ -4,7 +4,6 @@ ADVANTAGE_FN, GRPOAdvantageFn, OPMDAdvantageFn, - ASYMREAdvantageFn, PPOAdvantageFn, ) from trinity.algorithm.algorithm import ALGORITHM_TYPE, PPOAlgorithm diff --git a/trinity/utils/eval_utils.py b/trinity/utils/eval_utils.py index 3308825e81..b6fa7070f1 100644 --- a/trinity/utils/eval_utils.py +++ b/trinity/utils/eval_utils.py @@ -108,10 +108,22 @@ def compute_score(solution_str, ground_truth) -> float: retval = 0.0 try: string_in_last_boxed = last_boxed_only_string(solution_str) + original_ground_truth = ground_truth + boxed_ground_truth = last_boxed_only_string(ground_truth) + # Determine if ground_truth was raw (had boxed content) or already processed + ground_truth_was_raw = boxed_ground_truth is not None if 
string_in_last_boxed is not None: answer = remove_boxed(string_in_last_boxed) + if ground_truth_was_raw: + # Ground truth had boxed content - remove it + ground_truth = remove_boxed(boxed_ground_truth) + else: + # Ground truth had no boxed content - use as is + ground_truth = original_ground_truth if is_equiv(answer, ground_truth): retval = 1.0 + # logger.warning(answer, " ", ground_truth, " ", retval) + except Exception as e: print(e) From 7c961b918c9456a08e4ea29c92a57ddb39beb689 Mon Sep 17 00:00:00 2001 From: cr Date: Wed, 13 Aug 2025 17:13:46 -0700 Subject: [PATCH 04/10] [Algo] Adapted to the new version. --- trinity/algorithm/algorithm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 582c059c86..e3f91d2906 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -182,7 +182,7 @@ class AsymREAlgorithm(AlgorithmType): use_critic: bool = False use_reference: bool = False - use_advantage: bool = True + compute_advantage_in_trainer: bool = True can_balance_batch: bool = True schema: type = ExperienceModel From 25019996d9c35719c51f52501c610b42d0c3bb80 Mon Sep 17 00:00:00 2001 From: cr Date: Thu, 14 Aug 2025 11:13:23 -0700 Subject: [PATCH 05/10] [Test] Added unittest for the compute_score func. --- tests/utils/eval_utils_test.py | 93 +++++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/tests/utils/eval_utils_test.py b/tests/utils/eval_utils_test.py index 4cb137bf73..cc878a2cd1 100644 --- a/tests/utils/eval_utils_test.py +++ b/tests/utils/eval_utils_test.py @@ -3,9 +3,100 @@ import unittest -from trinity.utils.eval_utils import is_equiv +from trinity.utils.eval_utils import is_equiv, compute_score from trinity.utils.math_eval_utils import extract_answer, verify_math_answer +class TestComputeScore(unittest.TestCase): + """ + A suite of unit tests for the compute_score function. 
+ """ + + def test_both_boxed_and_equivalent(self): + """ + Tests the case where both solution and ground truth have equivalent boxed answers. + Expected score: 1.0 + """ + solution = "The final answer is \\boxed{42}" + truth = "The correct result is \\boxed{42}" + self.assertEqual(compute_score(solution, truth), 1.0) + + def test_solution_raw_and_ground_truth_boxed_equivalent(self): + """ + Tests the case where the solution is a raw string and the ground truth is boxed, but they are equivalent. + Expected score: 1.0 + """ + solution = "The answer is \\boxed{42}" + truth = "The answer is \\boxed{42}" + self.assertEqual(compute_score(solution, truth), 1.0) + + def test_solution_boxed_truth_raw_and_equivalent(self): + """ + Tests the case where the solution is boxed and the ground truth is a raw, equivalent string. + Expected score: 1.0 + """ + solution = "Let's see, the result is \\boxed{100}" + truth = "100" + self.assertEqual(compute_score(solution, truth), 1.0) + + def test_both_boxed_and_not_equivalent(self): + """ + Tests the case where both have boxed answers, but they are not equivalent. + Expected score: 0.0 + """ + solution = "I think the answer is \\boxed{-1}" + truth = "The answer is \\boxed{1}" + self.assertEqual(compute_score(solution, truth), 0.0) + + def test_solution_boxed_truth_raw_and_not_equivalent(self): + """ + Tests the case where the solution is boxed and the ground truth is a raw, non-equivalent string. + Expected score: 0.0 + """ + solution = "The answer is \\boxed{apple}" + truth = "orange" + self.assertEqual(compute_score(solution, truth), 0.0) + + def test_solution_not_boxed(self): + """ + Tests the case where the solution string does not contain a boxed answer. + Expected score: 0.0, regardless of the ground truth. + """ + solution = "The answer is 42, but I'm not boxing it." 
+ truth_boxed = "The answer is \\boxed{42}" + truth_raw = "42" + self.assertEqual(compute_score(solution, truth_boxed), 0.0) + self.assertEqual(compute_score(solution, truth_raw), 0.0) + + def test_empty_solution_string(self): + """ + Tests behavior with an empty solution string. + Expected score: 0.0 + """ + solution = "" + truth = "\\boxed{10}" + self.assertEqual(compute_score(solution, truth), 0.0) + + def test_empty_ground_truth(self): + """ + Tests behavior with an empty ground truth string. + Expected score: 0.0 unless the boxed answer is also empty. + """ + solution_correct = "The answer is \\boxed{}" + solution_incorrect = "The answer is \\boxed{1}" + truth = "" + self.assertEqual(compute_score(solution_correct, truth), 1.0) + self.assertEqual(compute_score(solution_incorrect, truth), 0.0) + + def test_multiple_boxed_answers_in_solution(self): + """ + Tests that only the *last* boxed answer in the solution is used for scoring. + """ + solution = "First I thought it was \\boxed{A}, but then I realized it is \\boxed{B}" + truth_correct = "\\boxed{B}" + truth_incorrect = "\\boxed{A}" + self.assertEqual(compute_score(solution, truth_correct), 1.0) + self.assertEqual(compute_score(solution, truth_incorrect), 0.0) + class TestMathEvalUtils(unittest.TestCase): def test_extract_answer(self): From 9b9da856d296895b7f5d1aa60c842b7f0cc42296 Mon Sep 17 00:00:00 2001 From: cr Date: Thu, 14 Aug 2025 11:41:22 -0700 Subject: [PATCH 06/10] [Algo] Delegate group advantage calculation from Trainer to Explorer. 
--- examples/asymre_math/math.yaml | 4 ++-- trinity/algorithm/algorithm.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/asymre_math/math.yaml b/examples/asymre_math/math.yaml index 5e6c081644..c64bb58d05 100644 --- a/examples/asymre_math/math.yaml +++ b/examples/asymre_math/math.yaml @@ -7,8 +7,8 @@ algorithm: algorithm_type: asymre policy_loss_fn_args: tau: 0 - advantage_fn_args: - asymre_baseline_shift: -0.1 + add_strategy_args: + baseline_shift: -0.1 # Baseline shift for the AsymREAddStrategy repeat_times: 8 cluster: node_num: 1 diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index e3f91d2906..12786280af 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -182,7 +182,7 @@ class AsymREAlgorithm(AlgorithmType): use_critic: bool = False use_reference: bool = False - compute_advantage_in_trainer: bool = True + compute_advantage_in_trainer: bool = False can_balance_batch: bool = True schema: type = ExperienceModel @@ -192,7 +192,7 @@ def default_config(cls) -> Dict: "repeat_times": 2, "sample_strategy": "warmup", "policy_loss_fn": "opmd", - "advantage_fn": "asymre", + "add_strategy": "asymre", "kl_penalty_fn": "none", "kl_loss_fn": "none", "entropy_loss_fn": "none", From f0fb43678a49c4740a63cce27e1fb13827e7eaa1 Mon Sep 17 00:00:00 2001 From: cr Date: Thu, 14 Aug 2025 11:49:20 -0700 Subject: [PATCH 07/10] [Format] --- tests/utils/eval_utils_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/utils/eval_utils_test.py b/tests/utils/eval_utils_test.py index cc878a2cd1..533c04b836 100644 --- a/tests/utils/eval_utils_test.py +++ b/tests/utils/eval_utils_test.py @@ -3,9 +3,10 @@ import unittest -from trinity.utils.eval_utils import is_equiv, compute_score +from trinity.utils.eval_utils import compute_score, is_equiv from trinity.utils.math_eval_utils import extract_answer, verify_math_answer + class TestComputeScore(unittest.TestCase): """ A suite 
of unit tests for the compute_score function. @@ -96,7 +97,7 @@ def test_multiple_boxed_answers_in_solution(self): truth_incorrect = "\\boxed{A}" self.assertEqual(compute_score(solution, truth_correct), 1.0) self.assertEqual(compute_score(solution, truth_incorrect), 0.0) - + class TestMathEvalUtils(unittest.TestCase): def test_extract_answer(self): From 47a09bbac95afda80f51eb8b23d8657d091293e5 Mon Sep 17 00:00:00 2001 From: cr Date: Mon, 18 Aug 2025 12:16:00 -0700 Subject: [PATCH 08/10] [doc] Added references. --- examples/asymre_math/README.md | 2 +- examples/asymre_math/math.yaml | 5 ++++- examples/asymre_math/train_math.yaml | 4 ++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/asymre_math/README.md b/examples/asymre_math/README.md index cb56bcb32a..746b4e8695 100644 --- a/examples/asymre_math/README.md +++ b/examples/asymre_math/README.md @@ -1,6 +1,6 @@ # Example: AsymRE on MATH dataset -This example shows the usage of AsymRE on the [MATH dataset](https://huggingface.co/datasets/nlile/hendrycks-MATH-benchmark). +This example shows the usage of [AsymRE](https://arxiv.org/abs/2506.20520) on the [MATH dataset](https://huggingface.co/datasets/nlile/hendrycks-MATH-benchmark). For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md). diff --git a/examples/asymre_math/math.yaml b/examples/asymre_math/math.yaml index c64bb58d05..6ad28dddfa 100644 --- a/examples/asymre_math/math.yaml +++ b/examples/asymre_math/math.yaml @@ -1,3 +1,6 @@ +# Configuration file for the AsymRE Math project. +# REINFORCE for off-Policy Reinforcement Learning: Balancing positive and negative rewards +# https://arxiv.org/abs/2506.20520. 
project: "Trinity-RFT-MATH" name: asymre_math checkpoint_root_dir: /PATH/TO/CHECKPOINT/ @@ -64,7 +67,7 @@ explorer: seed: 42 synchronizer: sync_method: 'nccl' - sync_interval: 500 + sync_interval: 250 sync_timeout: 3600 # Increased from 2000 to 3600 seconds (1 hour) trainer: trainer_type: 'verl' diff --git a/examples/asymre_math/train_math.yaml b/examples/asymre_math/train_math.yaml index 735d18c2b2..6b11b5bfa5 100644 --- a/examples/asymre_math/train_math.yaml +++ b/examples/asymre_math/train_math.yaml @@ -9,13 +9,13 @@ actor_rollout_ref: strategy: fsdp # This is for backward-compatibility ppo_micro_batch_size_per_gpu: 8 use_dynamic_bsz: True # False - ppo_max_token_len_per_gpu: 8192 # n * ${data.max_prompt_length} + ${data.max_response_length} + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size optim: - lr: 1e-6 + lr: 6e-8 lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime # min_lr_ratio: null # only useful for warmup with cosine warmup_style: constant # select from constant/cosine From e59fdf6999b52c3714a44f4916709a7bfb0e706b Mon Sep 17 00:00:00 2001 From: cr Date: Tue, 26 Aug 2025 16:32:47 -0700 Subject: [PATCH 09/10] [Feat] Removed add_strategy. 
--- .../advantage_fn/asymre_advantage.py | 59 ++++++++++++++++--- trinity/algorithm/algorithm.py | 44 +++++++------- 2 files changed, 71 insertions(+), 32 deletions(-) diff --git a/trinity/algorithm/advantage_fn/asymre_advantage.py b/trinity/algorithm/advantage_fn/asymre_advantage.py index 0ac5b71370..711db6891a 100644 --- a/trinity/algorithm/advantage_fn/asymre_advantage.py +++ b/trinity/algorithm/advantage_fn/asymre_advantage.py @@ -1,23 +1,29 @@ """AsymRE advantage computation""" from collections import defaultdict -from typing import Dict, Tuple +from typing import Dict, List, Tuple import torch from verl import DataProto -from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.algorithm.advantage_fn.advantage_fn import ( + ADVANTAGE_FN, + AdvantageFn, + GroupAdvantage, +) +from trinity.common.experience import Experience, group_by +from trinity.utils.annotations import Deprecated - -@ADVANTAGE_FN.register_module("asymre") +@Deprecated +@ADVANTAGE_FN.register_module("asymre_verl") class ASYMREAdvantageFn(AdvantageFn): """AsymRE advantage computation""" def __init__( self, - asymre_baseline_shift: float = -0.1, + baseline_shift: float = -0.1, ) -> None: - self.asymre_baseline_shift = asymre_baseline_shift + self.baseline_shift = baseline_shift def __call__( self, @@ -40,7 +46,7 @@ def __call__( eos_mask = exps.batch["response_mask"] index = exps.non_tensor_batch["uid"] - asymre_baseline_shift = self.asymre_baseline_shift + baseline_shift = self.baseline_shift response_length = token_level_rewards.shape[-1] scores = token_level_rewards.sum(dim=-1) @@ -58,7 +64,7 @@ def __call__( # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?) 
elif len(id2score[idx]) > 1: id2baseline[idx] = ( - torch.mean(torch.tensor(id2score[idx])) + asymre_baseline_shift + torch.mean(torch.tensor(id2score[idx])) + baseline_shift ) else: raise ValueError(f"no score in prompt index: {idx}") @@ -78,5 +84,40 @@ def __call__( @classmethod def default_args(cls) -> Dict: return { - "asymre_baseline_shift": -0.1, + "baseline_shift": -0.1, } + + + +@ADVANTAGE_FN.register_module("asymre") +class ASYMREGroupAdvantage(GroupAdvantage): + """asymre Group Advantage computation""" + + def __init__(self, baseline_shift: float = -0.1, **kwargs) -> None: + super().__init__(**kwargs) + self.baseline_shift = baseline_shift + + def group_experiences(self, exps): + return group_by(exps, id_type="task") + + def calculate_group_advantage( + self, group_id: str, exps: List[Experience] + ) -> Tuple[List[Experience], Dict]: + with torch.no_grad(): + if len(exps) == 1: + group_baseline = torch.tensor(0.0) + else: + group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32) + group_baseline = torch.mean(group_rewards)+ self.baseline_shift + for exp in exps: + score = exp.reward - group_baseline + exp.advantages = score * exp.action_mask + exp.returns = exp.advantages.clone() + metrics = { + "group_baseline": group_baseline.item(), + } + return exps, metrics + + @classmethod + def default_args(cls) -> dict: + return {"baseline_shift": -0.1} \ No newline at end of file diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 12786280af..7f01d1dbba 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -130,7 +130,28 @@ def default_config(cls) -> Dict: "kl_loss_fn": "k2", "entropy_loss_fn": "default", } + +@ALGORITHM_TYPE.register_module("asymre") +class AsymREAlgorithm(AlgorithmType): + """AsymRE algorithm.""" + use_critic: bool = False + use_reference: bool = False + compute_advantage_in_trainer: bool = False + can_balance_batch: bool = True + schema: type = ExperienceModel 
+ + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 2, + "sample_strategy": "warmup", + "policy_loss_fn": "opmd", + "advantage_fn": "asymre", + "kl_penalty_fn": "none", + "kl_loss_fn": "none", + "entropy_loss_fn": "none", + } @ALGORITHM_TYPE.register_module("dpo") class DPOAlgorithm(AlgorithmType): @@ -176,29 +197,6 @@ def check_config(cls, config: Config) -> None: logger.warning("DPO must use KL loss. Set `algorithm.kl_loss_fn` to `k2`") -@ALGORITHM_TYPE.register_module("asymre") -class AsymREAlgorithm(AlgorithmType): - """AsymRE algorithm.""" - - use_critic: bool = False - use_reference: bool = False - compute_advantage_in_trainer: bool = False - can_balance_batch: bool = True - schema: type = ExperienceModel - - @classmethod - def default_config(cls) -> Dict: - return { - "repeat_times": 2, - "sample_strategy": "warmup", - "policy_loss_fn": "opmd", - "add_strategy": "asymre", - "kl_penalty_fn": "none", - "kl_loss_fn": "none", - "entropy_loss_fn": "none", - } - - @ALGORITHM_TYPE.register_module("mix") class MIXAlgorithm(AlgorithmType): """MIX algorithm.""" From 26cf8c358fe625efed88a30d5ffd21e70240cd1c Mon Sep 17 00:00:00 2001 From: cr Date: Thu, 28 Aug 2025 11:42:58 +0800 Subject: [PATCH 10/10] [example] AsymRE on GSM8k. 
--- examples/asymre_gsm8k/README.md | 7 ++ examples/asymre_gsm8k/gsm8k.yaml | 66 +++++++++++++++++++ examples/asymre_gsm8k/train_gsm8k.yaml | 48 ++++++++++++++ examples/asymre_math/math.yaml | 6 +- trinity/algorithm/advantage_fn/__init__.py | 3 - .../advantage_fn/asymre_advantage.py | 10 ++- trinity/algorithm/algorithm.py | 4 +- 7 files changed, 132 insertions(+), 12 deletions(-) create mode 100644 examples/asymre_gsm8k/README.md create mode 100644 examples/asymre_gsm8k/gsm8k.yaml create mode 100644 examples/asymre_gsm8k/train_gsm8k.yaml diff --git a/examples/asymre_gsm8k/README.md b/examples/asymre_gsm8k/README.md new file mode 100644 index 0000000000..61a5cb9bf9 --- /dev/null +++ b/examples/asymre_gsm8k/README.md @@ -0,0 +1,7 @@ +# Example: AsymRE on GSM8k dataset + +This example shows the usage of [AsymRE](https://arxiv.org/abs/2506.20520) on the [GSM8k dataset](https://huggingface.co/datasets/openai/gsm8k). + +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md). + +The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml). 
diff --git a/examples/asymre_gsm8k/gsm8k.yaml b/examples/asymre_gsm8k/gsm8k.yaml new file mode 100644 index 0000000000..d825d8f89f --- /dev/null +++ b/examples/asymre_gsm8k/gsm8k.yaml @@ -0,0 +1,66 @@ +project: sync_offset_0_sync20 +name: asymre-gsm8k_shift-0.1 +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +model: + model_path: /PATH/TO/MODEL/ + max_response_tokens: 1024 + max_model_len: 1280 +algorithm: + algorithm_type: asymre + policy_loss_fn_args: + tau: 0 + advantage_fn_args: + baseline_shift: -0.1 + repeat_times: 8 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_steps: 80 + batch_size: 96 + max_retry_times: 3 + max_retry_interval: 1 + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: /PATH/TO/DATASET/ + split: train + format: + prompt_key: question + response_key: answer + rollout_args: + temperature: 1.0 + eval_tasksets: + - name: gsm8k-eval + storage_type: file + path: /PATH/TO/DATASET/ + split: test + format: + prompt_key: question + response_key: answer + default_workflow_type: math_workflow + trainer_input: + experience_buffer: + name: gsm8k_buffer + storage_type: queue +explorer: + eval_interval: 20 + runner_num: 64 + rollout_model: + engine_type: vllm_async + engine_num: 4 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: nccl + sync_interval: 20 + sync_timeout: 1200 + sync_offset: 0 +trainer: + trainer_type: verl + trainer_config_path: examples/asymre_gsm8k/train_gsm8k.yaml + save_interval: 100 diff --git a/examples/asymre_gsm8k/train_gsm8k.yaml b/examples/asymre_gsm8k/train_gsm8k.yaml new file mode 100644 index 0000000000..00b6ffb9f6 --- /dev/null +++ b/examples/asymre_gsm8k/train_gsm8k.yaml @@ -0,0 +1,48 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + 
ppo_micro_batch_size_per_gpu: 8 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size_per_gpu: 16 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + +trainer: + balance_batch: True + # auto: find the last ckpt to resume. 
If it can't find one, start from scratch + resume_mode: auto # or disable, or resume_path if resume_from_path is set
id2baseline[idx] = id2score[idx] (so that this sample won't take effect?) elif len(id2score[idx]) > 1: - id2baseline[idx] = ( - torch.mean(torch.tensor(id2score[idx])) + baseline_shift - ) + id2baseline[idx] = torch.mean(torch.tensor(id2score[idx])) + baseline_shift else: raise ValueError(f"no score in prompt index: {idx}") for i in range(bsz): @@ -88,7 +87,6 @@ def default_args(cls) -> Dict: } - @ADVANTAGE_FN.register_module("asymre") class ASYMREGroupAdvantage(GroupAdvantage): """asymre Group Advantage computation""" @@ -108,7 +106,7 @@ def calculate_group_advantage( group_baseline = torch.tensor(0.0) else: group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32) - group_baseline = torch.mean(group_rewards)+ self.baseline_shift + group_baseline = torch.mean(group_rewards) + self.baseline_shift for exp in exps: score = exp.reward - group_baseline exp.advantages = score * exp.action_mask @@ -120,4 +118,4 @@ def calculate_group_advantage( @classmethod def default_args(cls) -> dict: - return {"baseline_shift": -0.1} \ No newline at end of file + return {"baseline_shift": -0.1} diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 7f01d1dbba..fcedf20267 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -130,7 +130,8 @@ def default_config(cls) -> Dict: "kl_loss_fn": "k2", "entropy_loss_fn": "default", } - + + @ALGORITHM_TYPE.register_module("asymre") class AsymREAlgorithm(AlgorithmType): """AsymRE algorithm.""" @@ -153,6 +154,7 @@ def default_config(cls) -> Dict: "entropy_loss_fn": "none", } + @ALGORITHM_TYPE.register_module("dpo") class DPOAlgorithm(AlgorithmType): """DPO algorithm."""