sPPO #232 (Merged)

10 changes: 7 additions & 3 deletions examples/asymre_gsm8k/gsm8k.yaml
@@ -1,6 +1,10 @@
-project: sync_offset_0_sync20
-name: asymre-gsm8k_shift-0.1
-checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+# Configuration file for the AsymRE GSM8k project.
+# REINFORCE for off-Policy Reinforcement Learning: Balancing positive and negative rewards
+# https://arxiv.org/abs/2506.20520.
+
+project: "Trinity-RFT-GSM8K"
+name: asymre_gsm8k
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
 model:
   model_path: /PATH/TO/MODEL/
   max_response_tokens: 1024
7 changes: 7 additions & 0 deletions examples/sppo_gsm8k/README.md
@@ -0,0 +1,7 @@
# Example: sPPO on GSM8k dataset

This example shows the usage of [sPPO](https://arxiv.org/abs/2108.05828) on the [GSM8k dataset](https://huggingface.co/datasets/openai/gsm8k).

For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md).

The configuration files for this example are [`gsm8k.yaml`](gsm8k.yaml) and [`train_gsm8k.yaml`](train_gsm8k.yaml).
68 changes: 68 additions & 0 deletions examples/sppo_gsm8k/gsm8k.yaml
@@ -0,0 +1,68 @@
# Configuration file for the sPPO GSM8k project.
# A general class of surrogate functions for stable and efficient reinforcement learning
# https://arxiv.org/abs/2108.05828.

project: "Trinity-RFT-GSM8K"
name: sppo_gsm8k
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
model:
  model_path: /PATH/TO/MODEL/
  max_response_tokens: 1024
  max_model_len: 1280
algorithm:
  algorithm_type: sppo
  policy_loss_fn_args:
    epsilon: 0.1
  repeat_times: 8
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_steps: 100
  batch_size: 96
  max_retry_times: 3
  max_retry_interval: 1
  explorer_input:
    taskset:
      name: gsm8k
      storage_type: file
      path: /PATH/TO/DATASET/
      split: train
      format:
        prompt_key: question
        response_key: answer
      rollout_args:
        temperature: 1.0
    eval_tasksets:
    - name: gsm8k-eval
      storage_type: file
      path: /PATH/TO/DATASET/
      split: test
      format:
        prompt_key: question
        response_key: answer
    default_workflow_type: math_workflow
  trainer_input:
    experience_buffer:
      name: gsm8k_buffer
      storage_type: queue
explorer:
  eval_interval: 20
  runner_num: 64
  rollout_model:
    engine_type: vllm_async
    engine_num: 4
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: true
    dtype: bfloat16
    seed: 42
synchronizer:
  sync_method: nccl
  sync_interval: 20
  sync_timeout: 1200
  sync_offset: 0
trainer:
  trainer_type: verl
  trainer_config_path: examples/sppo_gsm8k/train_gsm8k.yaml
  save_interval: 100
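
The only sPPO-specific setting above is `epsilon` under `algorithm.policy_loss_fn_args`; the loss function added later in this PR keeps only tokens whose importance ratio stays inside `[1/(1+epsilon), 1+epsilon]`. A tiny illustrative snippet (plain Python, not part of the PR) of the interval implied by `epsilon: 0.1`:

```python
# Illustration only: the ratio interval implied by policy_loss_fn_args.epsilon.
# Tokens whose ratio exp(logprob - old_logprob) falls outside this interval are
# masked out of the sPPO loss (see sppo_loss_fn.py below).
epsilon = 0.1  # value from gsm8k.yaml above; the code default is 0.3
lower, upper = 1.0 / (1.0 + epsilon), 1.0 + epsilon
print(f"ratio kept within [{lower:.4f}, {upper:.4f}]")  # ratio kept within [0.9091, 1.1000]
```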
48 changes: 48 additions & 0 deletions examples/sppo_gsm8k/train_gsm8k.yaml
@@ -0,0 +1,48 @@
actor_rollout_ref:
  hybrid_engine: True
  model:
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: True
    use_remove_padding: True # False
  actor:
    strategy: fsdp # This is for backward-compatibility
    ppo_micro_batch_size_per_gpu: 8
    use_dynamic_bsz: True # False
    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
    grad_clip: 1.0
    ppo_epochs: 1
    shuffle: False
    ulysses_sequence_parallel_size: 1 # sp size
    optim:
      lr: 1e-6
      lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
      # min_lr_ratio: null # only useful for warmup with cosine
      warmup_style: constant # select from constant/cosine
      total_training_steps: -1
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: False
      optimizer_offload: False
      fsdp_size: -1
  ref:
    fsdp_config:
      param_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
    log_prob_micro_batch_size_per_gpu: 16
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size

trainer:
  balance_batch: True
  # auto: find the last ckpt to resume. If can't find, start from scratch
  resume_mode: auto # or auto or resume_path if
  default_hdfs_dir: null
  remove_previous_ckpt_in_save: False
  del_local_ckpt_after_load: False
  val_before_train: False
1 change: 1 addition & 0 deletions trinity/algorithm/advantage_fn/asymre_advantage.py
@@ -113,6 +113,7 @@ def calculate_group_advantage(
exp.returns = exp.advantages.clone()
metrics = {
"group_baseline": group_baseline.item(),
"reward_mean": group_baseline.item() - self.baseline_shift,
}
return exps, metrics

1 change: 1 addition & 0 deletions trinity/algorithm/advantage_fn/opmd_advantage.py
@@ -136,6 +136,7 @@ def calculate_group_advantage(
exp.returns = exp.advantages.clone()
metrics = {
"group_baseline": group_baseline.item(),
"reward_mean": torch.mean(group_rewards).item(),
}
return exps, metrics

23 changes: 23 additions & 0 deletions trinity/algorithm/algorithm.py
@@ -269,3 +269,26 @@ def default_config(cls) -> Dict:
            "kl_loss_fn": "none",
            "entropy_loss_fn": "none",
        }


@ALGORITHM_TYPE.register_module("sppo")
class sPPOAlgorithm(AlgorithmType):
    """sPPO Algorithm."""

    use_critic: bool = False
    use_reference: bool = False
    compute_advantage_in_trainer: bool = False
    can_balance_batch: bool = True
    schema: str = "experience"

    @classmethod
    def default_config(cls) -> Dict:
        return {
            "repeat_times": 2,
            "sample_strategy": "warmup",
            "policy_loss_fn": "sppo",
            "advantage_fn": "opmd",
            "kl_penalty_fn": "none",
            "kl_loss_fn": "none",
            "entropy_loss_fn": "none",
        }
2 changes: 2 additions & 0 deletions trinity/algorithm/policy_loss_fn/__init__.py
@@ -10,6 +10,7 @@
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
from trinity.algorithm.policy_loss_fn.sppo_loss_fn import sPPOPolicyLossFn

__all__ = [
    "POLICY_LOSS_FN",
@@ -23,4 +24,5 @@
    "MIXCHORDPolicyLossFn",
    "SFTISLossFn",
    "SFTPhiLossFn",
    "sPPOPolicyLossFn",
]
54 changes: 54 additions & 0 deletions trinity/algorithm/policy_loss_fn/sppo_loss_fn.py
@@ -0,0 +1,54 @@
"""sPPO-token policy loss function.
Relevant paper: https://arxiv.org/abs/2108.05828.
"""

from typing import Dict, Tuple

import torch

from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
from trinity.algorithm.utils import masked_mean


@POLICY_LOSS_FN.register_module("sppo")
class sPPOPolicyLossFn(PolicyLossFn):
def __init__(
self,
backend: str = "verl",
epsilon: float = 0.3,
) -> None:
super().__init__(backend=backend)
self.epsilon = epsilon

def __call__( # type: ignore
self,
logprob: torch.Tensor, # [batch_size, seq_len]
old_logprob: torch.Tensor, # [batch_size, seq_len]
action_mask: torch.Tensor, # [batch_size, seq_len]
advantages: torch.Tensor, # [batch_size, seq_len]
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
"""Calculate sPPO loss.
The formula is as follows:
advantages*log(clip(ratio, 1/(1+epsilon), 1+epsilon))
ratio = exp(logprob - old_logprob)
"""
#
# token-wise
ratio = torch.exp(logprob - old_logprob).detach()
is_in_range = (ratio >= (1 / (1 + self.epsilon))) * (ratio <= (1 + self.epsilon))
is_clipped_mask = ~is_in_range
pg_losses = -advantages * (logprob - old_logprob) * is_in_range.float()
pg_loss = masked_mean(pg_losses, action_mask)
pg_clipfrac = masked_mean(is_clipped_mask.float(), action_mask)
metrics = {
"pg_clipfrac": pg_clipfrac.item(),
"pg_loss": pg_loss.detach().item(),
}
return pg_loss, metrics

@classmethod
def default_args(cls) -> Dict:
return {
"epsilon": 0.3,
}