diff --git a/examples/asymre_gsm8k/gsm8k.yaml b/examples/asymre_gsm8k/gsm8k.yaml
index 3eeab4d3f4..254b5987f6 100644
--- a/examples/asymre_gsm8k/gsm8k.yaml
+++ b/examples/asymre_gsm8k/gsm8k.yaml
@@ -1,6 +1,10 @@
-project: sync_offset_0_sync20
-name: asymre-gsm8k_shift-0.1
-checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+# Configuration file for the AsymRE GSM8k project.
+# REINFORCE for off-Policy Reinforcement Learning: Balancing positive and negative rewards
+# https://arxiv.org/abs/2506.20520.
+
+project: "Trinity-RFT-GSM8K"
+name: asymre_gsm8k
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 model:
   model_path: /PATH/TO/MODEL/
   max_response_tokens: 1024
diff --git a/examples/sppo_gsm8k/README.md b/examples/sppo_gsm8k/README.md
new file mode 100644
index 0000000000..2f51a1b2be
--- /dev/null
+++ b/examples/sppo_gsm8k/README.md
@@ -0,0 +1,7 @@
+# Example: sPPO on GSM8k dataset
+
+This example shows the usage of [sPPO](https://arxiv.org/abs/2108.05828) on the [GSM8k dataset](https://huggingface.co/datasets/openai/gsm8k).
+
+For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md).
+
+The config files are located in [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml).
diff --git a/examples/sppo_gsm8k/gsm8k.yaml b/examples/sppo_gsm8k/gsm8k.yaml
new file mode 100644
index 0000000000..d72459afe3
--- /dev/null
+++ b/examples/sppo_gsm8k/gsm8k.yaml
@@ -0,0 +1,68 @@
+# Configuration file for the sPPO GSM8k project.
+# A general class of surrogate functions for stable and efficient reinforcement learning
+# https://arxiv.org/abs/2108.05828.
+
+project: "Trinity-RFT-GSM8K"
+name: sppo_gsm8k
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+model:
+  model_path: /PATH/TO/MODEL/
+  max_response_tokens: 1024
+  max_model_len: 1280
+algorithm:
+  algorithm_type: sppo
+  policy_loss_fn_args:
+    epsilon: 0.1
+  repeat_times: 8
+cluster:
+  node_num: 1
+  gpu_per_node: 8
+buffer:
+  total_steps: 100
+  batch_size: 96
+  max_retry_times: 3
+  max_retry_interval: 1
+  explorer_input:
+    taskset:
+      name: gsm8k
+      storage_type: file
+      path: /PATH/TO/DATASET/
+      split: train
+      format:
+        prompt_key: question
+        response_key: answer
+      rollout_args:
+        temperature: 1.0
+    eval_tasksets:
+      - name: gsm8k-eval
+        storage_type: file
+        path: /PATH/TO/DATASET/
+        split: test
+        format:
+          prompt_key: question
+          response_key: answer
+    default_workflow_type: math_workflow
+  trainer_input:
+    experience_buffer:
+      name: gsm8k_buffer
+      storage_type: queue
+explorer:
+  eval_interval: 20
+  runner_num: 64
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 4
+    tensor_parallel_size: 1
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    seed: 42
+synchronizer:
+  sync_method: nccl
+  sync_interval: 20
+  sync_timeout: 1200
+  sync_offset: 0
+trainer:
+  trainer_type: verl
+  trainer_config_path: examples/sppo_gsm8k/train_gsm8k.yaml
+  save_interval: 100
diff --git a/examples/sppo_gsm8k/train_gsm8k.yaml b/examples/sppo_gsm8k/train_gsm8k.yaml
new file mode 100644
index 0000000000..00b6ffb9f6
--- /dev/null
+++ b/examples/sppo_gsm8k/train_gsm8k.yaml
@@ -0,0 +1,48 @@
+actor_rollout_ref:
+  hybrid_engine: True
+  model:
+    external_lib: null
+    override_config: { }
+    enable_gradient_checkpointing: True
+    use_remove_padding: True # False
+  actor:
+    strategy: fsdp # This is for backward-compatibility
+    ppo_micro_batch_size_per_gpu: 8
+    use_dynamic_bsz: True # False
+    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
+    grad_clip: 1.0
+    ppo_epochs: 1
+    shuffle: False
+    ulysses_sequence_parallel_size: 1 # sp size
+    optim:
+      lr: 1e-6
+      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
+      # min_lr_ratio: null # only useful for warmup with cosine
+      warmup_style: constant # select from constant/cosine
+      total_training_steps: -1
+    fsdp_config:
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+      param_offload: False
+      optimizer_offload: False
+      fsdp_size: -1
+  ref:
+    fsdp_config:
+      param_offload: False
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+    log_prob_micro_batch_size_per_gpu: 16
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
+
+trainer:
+  balance_batch: True
+  # auto: resume from the last checkpoint if one exists; otherwise start from scratch
+  resume_mode: auto # or disable, or resume_path
+  default_hdfs_dir: null
+  remove_previous_ckpt_in_save: False
+  del_local_ckpt_after_load: False
+  val_before_train: False
diff --git a/trinity/algorithm/advantage_fn/asymre_advantage.py b/trinity/algorithm/advantage_fn/asymre_advantage.py
index bd97f328ca..81aa6659ff 100644
--- a/trinity/algorithm/advantage_fn/asymre_advantage.py
+++ b/trinity/algorithm/advantage_fn/asymre_advantage.py
@@ -113,6 +113,7 @@ def calculate_group_advantage(
             exp.returns = exp.advantages.clone()
         metrics = {
             "group_baseline": group_baseline.item(),
+            "reward_mean": group_baseline.item() - self.baseline_shift,
         }
         return exps, metrics
diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py
index ab00b6cc3c..2707e11e0b 100644
--- a/trinity/algorithm/advantage_fn/opmd_advantage.py
+++ b/trinity/algorithm/advantage_fn/opmd_advantage.py
@@ -136,6 +136,7 @@ def calculate_group_advantage(
             exp.returns = exp.advantages.clone()
         metrics = {
             "group_baseline": group_baseline.item(),
+            "reward_mean": torch.mean(group_rewards).item(),
         }
         return exps, metrics
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index e3fac8d1e2..c987bcf608 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -269,3 +269,26 @@ def default_config(cls) -> Dict:
             "kl_loss_fn": "none",
             "entropy_loss_fn": "none",
         }
+
+
+@ALGORITHM_TYPE.register_module("sppo")
+class sPPOAlgorithm(AlgorithmType):
+    """sPPO Algorithm."""
+
+    use_critic: bool = False
+    use_reference: bool = False
+    compute_advantage_in_trainer: bool = False
+    can_balance_batch: bool = True
+    schema: str = "experience"
+
+    @classmethod
+    def default_config(cls) -> Dict:
+        return {
+            "repeat_times": 2,
+            "sample_strategy": "warmup",
+            "policy_loss_fn": "sppo",
+            "advantage_fn": "opmd",
+            "kl_penalty_fn": "none",
+            "kl_loss_fn": "none",
+            "entropy_loss_fn": "none",
+        }
diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py
index 7416e94d5b..fd604f7f33 100644
--- a/trinity/algorithm/policy_loss_fn/__init__.py
+++ b/trinity/algorithm/policy_loss_fn/__init__.py
@@ -10,6 +10,7 @@
 from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
 from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
 from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
+from trinity.algorithm.policy_loss_fn.sppo_loss_fn import sPPOPolicyLossFn

 __all__ = [
"POLICY_LOSS_FN", @@ -23,4 +24,5 @@ "MIXCHORDPolicyLossFn", "SFTISLossFn", "SFTPhiLossFn", + "sPPOPolicyLossFn", ] diff --git a/trinity/algorithm/policy_loss_fn/sppo_loss_fn.py b/trinity/algorithm/policy_loss_fn/sppo_loss_fn.py new file mode 100644 index 0000000000..3dd248b6c2 --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/sppo_loss_fn.py @@ -0,0 +1,54 @@ +"""sPPO-token policy loss function. +Relevant paper: https://arxiv.org/abs/2108.05828. +""" + +from typing import Dict, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.utils import masked_mean + + +@POLICY_LOSS_FN.register_module("sppo") +class sPPOPolicyLossFn(PolicyLossFn): + def __init__( + self, + backend: str = "verl", + epsilon: float = 0.3, + ) -> None: + super().__init__(backend=backend) + self.epsilon = epsilon + + def __call__( # type: ignore + self, + logprob: torch.Tensor, # [batch_size, seq_len] + old_logprob: torch.Tensor, # [batch_size, seq_len] + action_mask: torch.Tensor, # [batch_size, seq_len] + advantages: torch.Tensor, # [batch_size, seq_len] + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + """Calculate sPPO loss. + The formula is as follows: + advantages*log(clip(ratio, 1/(1+epsilon), 1+epsilon)) + ratio = exp(logprob - old_logprob) + """ + # + # token-wise + ratio = torch.exp(logprob - old_logprob).detach() + is_in_range = (ratio >= (1 / (1 + self.epsilon))) * (ratio <= (1 + self.epsilon)) + is_clipped_mask = ~is_in_range + pg_losses = -advantages * (logprob - old_logprob) * is_in_range.float() + pg_loss = masked_mean(pg_losses, action_mask) + pg_clipfrac = masked_mean(is_clipped_mask.float(), action_mask) + metrics = { + "pg_clipfrac": pg_clipfrac.item(), + "pg_loss": pg_loss.detach().item(), + } + return pg_loss, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "epsilon": 0.3, + }