11 changes: 9 additions & 2 deletions skyrl-train/docs/configuration/config.rst
@@ -291,7 +291,8 @@ Algorithm Configuration
advantage_batch_normalize: false
value_head_prefix: "value_head"
policy_loss_type: "regular" # "regular", "dual_clip", "gspo", or customizable with PolicyLossRegistry
loss_reduction: "token_mean" # "token_mean", "sequence_mean"
loss_reduction: "token_mean" # "token_mean", "sequence_mean", "seq_mean_token_sum_norm"
grpo_norm_by_std: true # set to false to disable normalization by std in GRPO (used in Dr. GRPO)

# GAE parameters
lambd: 1.0
@@ -329,7 +330,13 @@ Algorithm Configuration
- ``gspo``: `Group Sequence Policy Optimization <https://arxiv.org/abs/2507.18071>`_ with sequence-level importance sampling for improved training stability. Implements "GSPO-token" variant from the paper.
- Custom policy losses can be registered with the ``PolicyLossRegistry``

- ``algorithm.loss_reduction``: Type of loss reduction to use. Options are ``token_mean`` and ``sequence_mean``. ``token_mean`` matches token-level loss introduced by `DAPO <https://dapo-sia.github.io/>`_. ``sequence_mean`` computes per-sequence avg token loss, then averages over the batch.
- ``algorithm.loss_reduction``: Type of loss reduction to use. Options include:

- ``token_mean``: computes the average loss over all valid tokens in the batch. Used in `DAPO <https://dapo-sia.github.io/>`_.
- ``sequence_mean``: computes the per-sequence average token loss, then averages over the batch.
- ``seq_mean_token_sum_norm``: computes the sum of token losses for each sequence, normalizes by the maximum sequence length (computed as ``cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length``), and then averages over the batch. This is the reduction used in `Dr. GRPO <https://arxiv.org/abs/2503.20783>`_; a standalone sketch of the three reductions follows this file's diff.

- ``algorithm.grpo_norm_by_std``: Whether to normalize advantages by the standard deviation in GRPO. This is set to ``false`` in `Dr. GRPO <https://arxiv.org/abs/2503.20783>`_.
- ``algorithm.lambd``: Lambda parameter for GAE.
- ``algorithm.gamma``: Gamma parameter for GAE.
- ``algorithm.eps_clip_low``: Lower bound for PPO clipping.
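For readers comparing the three reductions documented above, here is a minimal standalone sketch of their semantics. It mirrors the `reduce_loss` helper added to `ppo_utils.py` further down in this diff, but the function name, tensor shapes, and values are illustrative assumptions rather than code taken from the repository.

```python
import torch


def reduce_loss_sketch(loss, loss_mask, loss_reduction, max_seq_len=None):
    """Reduce a [batch, seq_len] per-token loss under the three documented modes."""
    if loss_reduction == "token_mean":
        # one global average over every valid token in the batch (DAPO-style)
        return (loss * loss_mask).sum() / loss_mask.sum()
    if loss_reduction == "sequence_mean":
        # per-sequence token mean, then mean over the batch
        return ((loss * loss_mask).sum(dim=-1) / loss_mask.sum(dim=-1)).mean()
    if loss_reduction == "seq_mean_token_sum_norm":
        # Dr. GRPO: per-sequence token sum divided by a constant max length,
        # then mean over the batch, so response length does not change the normalizer
        assert max_seq_len is not None, "max_seq_len is required for this mode"
        return ((loss * loss_mask).sum(dim=-1) / max_seq_len).mean()
    raise ValueError(f"unknown loss_reduction: {loss_reduction}")


# Two sequences with identical per-token loss but very different valid lengths.
loss = torch.ones(2, 4)
mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                     [1.0, 0.0, 0.0, 0.0]])
print(reduce_loss_sketch(loss, mask, "token_mean"))                              # tensor(1.)
print(reduce_loss_sketch(loss, mask, "sequence_mean"))                           # tensor(1.)
print(reduce_loss_sketch(loss, mask, "seq_mean_token_sum_norm", max_seq_len=4))  # tensor(0.6250)
```

The last line shows the point of the Dr. GRPO reduction: dividing by a fixed `max_seq_len` rather than each sequence's own valid length keeps the normalizer constant across responses of different lengths.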
61 changes: 61 additions & 0 deletions skyrl-train/examples/algorithm/drgrpo/run_drgrpo_gsm8k.sh
@@ -0,0 +1,61 @@
set -x

# Colocated Dr. GRPO training+generation for Qwen2.5-1.5B-Instruct on GSM8K.

# uv run examples/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
# export WANDB_API_KEY=<your_key_here>
# bash examples/algorithm/drgrpo/run_drgrpo_gsm8k.sh

# TODO (erictang000): add a description of the algorithm once GRPO docs are added.

DATA_DIR="$HOME/data/gsm8k"
NUM_GPUS=4
LOGGER="wandb" # change to "console" to print to stdout

# Dr. GRPO parameters

LOSS_REDUCTION="seq_mean_token_sum_norm"
GRPO_NORM_BY_STD=false
USE_KL_LOSS=false

uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.algorithm.advantage_estimator="grpo" \
trainer.algorithm.loss_reduction=$LOSS_REDUCTION \
trainer.algorithm.grpo_norm_by_std=$GRPO_NORM_BY_STD \
trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
generator.num_inference_engines=$NUM_GPUS \
generator.inference_engine_tensor_parallel_size=1 \
trainer.epochs=20 \
trainer.eval_batch_size=1024 \
trainer.eval_before_train=true \
trainer.eval_interval=5 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=1024 \
trainer.policy_mini_batch_size=256 \
trainer.micro_forward_batch_size_per_gpu=64 \
trainer.micro_train_batch_size_per_gpu=64 \
trainer.ckpt_interval=10 \
trainer.max_prompt_length=512 \
generator.sampling_params.max_generate_length=1024 \
trainer.policy.optimizer_config.lr=1.0e-6 \
generator.backend=vllm \
generator.run_engines_locally=true \
generator.weight_sync_backend=nccl \
generator.async_engine=true \
generator.batched=true \
environment.env_class=gsm8k \
generator.n_samples_per_prompt=5 \
generator.gpu_memory_utilization=0.8 \
trainer.logger="$LOGGER" \
trainer.project_name="gsm8k" \
trainer.run_name="gsm8k_drgrpo" \
trainer.resume_mode=null \
trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \
$@
2 changes: 1 addition & 1 deletion skyrl-train/examples/ppo/run_ppo.sh
@@ -43,7 +43,7 @@ uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \
generator.gpu_memory_utilization=0.8 \
trainer.logger="wandb" \
trainer.project_name="gsm8k" \
trainer.run_name="gsm8k_test" \
trainer.run_name="gsm8k_ppo" \
trainer.resume_mode=null \
trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt_ppo" \
trainer.eval_batch_size=1024 \
3 changes: 2 additions & 1 deletion skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -89,7 +89,8 @@ trainer:
advantage_batch_normalize: false
value_head_prefix: "value_head"
policy_loss_type: "regular" # "regular", "dual_clip", "gspo", or customizable with PolicyLossRegistry
loss_reduction: "token_mean" # "token_mean", "sequence_mean"
loss_reduction: "token_mean" # "token_mean", "sequence_mean", "seq_mean_token_sum_norm"
grpo_norm_by_std: true # set to false to disable normalization by std in GRPO
# GAE parameters
lambd: 1.0
gamma: 1.0
1 change: 1 addition & 0 deletions skyrl-train/skyrl_train/trainer.py
@@ -756,6 +756,7 @@ def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingIn
values=data["values"],
gamma=self.cfg.trainer.algorithm.gamma,
lambd=self.cfg.trainer.algorithm.lambd,
grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
)
data["returns"] = returns
data["advantages"] = advantages
56 changes: 46 additions & 10 deletions skyrl-train/skyrl_train/utils/ppo_utils.py
@@ -20,7 +20,6 @@
from enum import StrEnum
from typing import Callable, List, Tuple, Union, Optional, Literal
from functools import wraps

import torch
import numpy as np

@@ -160,6 +159,28 @@ def masked_whiten(values, mask, shift_mean=True):
return whitened


def ppo_critic_loss(
values: torch.Tensor,
old_values: torch.Tensor,
returns: torch.Tensor,
config: DictConfig,
loss_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[float]]:

if config.value_clip is not None:
values_clipped = old_values + (values - old_values).clamp(-config.value_clip, config.value_clip)
surr1 = (values_clipped - returns) ** 2
surr2 = (values - returns) ** 2
loss = torch.max(surr1, surr2)
clipfrac = masked_mean((surr1 > surr2).float(), loss_mask).mean().detach().item()
else:
clipfrac = None
loss = (values - returns) ** 2

loss = masked_mean(loss, loss_mask, dim=-1).mean()
return 0.5 * loss, clipfrac


# Shared registry actor class for both policy loss and advantage estimator registries
@ray.remote
class RegistryActor:
@@ -468,7 +489,8 @@ def ppo_policy_loss(
assert loss_reduction in [
"token_mean",
"sequence_mean",
], "loss_reduction must be either 'token_mean' or 'sequence_mean'"
"seq_mean_token_sum_norm",
], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'"

ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
@@ -480,7 +502,7 @@ def ppo_policy_loss(
pg_losses3 = -advantages * config.clip_ratio_c
clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
loss = reduce_loss(loss, loss_mask, loss_reduction)
loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)
return loss, clip_ratio


@@ -538,20 +560,34 @@ def gspo_policy_loss(
# Compute clipping ratio for monitoring
clip_ratio = masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item()

loss = reduce_loss(loss, loss_mask, loss_reduction)
loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len)

return loss, clip_ratio


def reduce_loss(
loss: torch.Tensor, loss_mask: Optional[torch.Tensor], loss_reduction: Literal["token_mean", "sequence_mean"]
loss: torch.Tensor,
loss_mask: Optional[torch.Tensor],
loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm"],
max_seq_len: Optional[int] = None,
) -> torch.Tensor:
if loss_reduction == "token_mean":
# sum over *all* valid tokens, divide by total valid-token count
loss = masked_mean(loss, loss_mask)
elif loss_reduction == "sequence_mean":
# per-sequence token-mean (dim=-1), then batch-mean
loss = masked_mean(loss, loss_mask, dim=-1).mean()
elif loss_reduction == "seq_mean_token_sum_norm":
# per-sequence token-sum, normalized by the max sequence length, then batch mean
# this is the Dr. GRPO loss reduction to avoid length bias by normalizing by a constant
assert max_seq_len is not None, "max_seq_len must be provided for seq_mean_token_sum_norm loss reduction"
# NOTE: max_seq_len is computed as cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length by default
if loss_mask is not None:
seq_losses = torch.sum(loss * loss_mask, dim=-1) / max_seq_len
else:
# If no mask, assume all tokens are valid
seq_losses = torch.sum(loss, dim=-1) / max_seq_len
loss = torch.mean(seq_losses)
else:
raise ValueError(f"Invalid loss reduction type: {loss_reduction}")
return loss
@@ -594,7 +630,7 @@ def compute_grpo_outcome_advantage(
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6,
norm_adv_by_std_in_grpo: bool = True,
grpo_norm_by_std: bool = True,
**kwargs,
):
"""
@@ -605,7 +641,7 @@
- response_mask: Float[torch.Tensor, "batch_size seqlen"]
- index: np.ndarray (batch_size)
- epsilon: float
- norm_adv_by_std_in_grpo: bool
- grpo_norm_by_std: bool

Returns:
- advantages: Float[torch.Tensor, "batch_size seqlen"]
@@ -632,7 +668,7 @@
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
if norm_adv_by_std_in_grpo:
if grpo_norm_by_std:
scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon)
else:
scores[i] = scores[i] - id2mean[index[i]]
@@ -647,7 +683,7 @@ def compute_advantages_and_returns(
index: np.ndarray,
adv_estimator: AdvantageEstimator,
values: Optional[torch.Tensor] = None,
norm_adv_by_std_in_grpo: bool = True,
grpo_norm_by_std: bool = True,
gamma=1.0,
lambd=1.0,
):
Expand All @@ -658,7 +694,7 @@ def compute_advantages_and_returns(
response_mask=response_mask,
index=index,
values=values,
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
grpo_norm_by_std=grpo_norm_by_std,
gamma=gamma,
lambd=lambd,
)
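As a quick numeric illustration of what `grpo_norm_by_std` toggles in `compute_grpo_outcome_advantage`: the group rewards below are made-up values, and the unbiased `torch.std` convention is an assumption of the sketch, not a claim about the exact estimator used in the implementation.

```python
import torch

# One prompt group of four sampled responses with outcome rewards (illustrative values).
rewards = torch.tensor([1.0, 0.0, 1.0, 0.0])
epsilon = 1e-6

group_mean = rewards.mean()   # 0.5
group_std = rewards.std()     # ~0.577 with torch's default unbiased estimator

# grpo_norm_by_std=true (GRPO): center and scale by the group std
adv_grpo = (rewards - group_mean) / (group_std + epsilon)

# grpo_norm_by_std=false (Dr. GRPO): center only
adv_dr_grpo = rewards - group_mean

print(adv_grpo)     # tensor([ 0.8660, -0.8660,  0.8660, -0.8660])
print(adv_dr_grpo)  # tensor([ 0.5000, -0.5000,  0.5000, -0.5000])
```

In the full implementation these per-response scalars are then expanded over response tokens and masked; the sketch only covers the normalization step that the flag controls.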
15 changes: 13 additions & 2 deletions skyrl-train/skyrl_train/utils/utils.py
@@ -4,7 +4,7 @@
import ray
import torch
from loguru import logger
from omegaconf.dictconfig import DictConfig
from omegaconf import DictConfig, OmegaConf
from ray.util.placement_group import placement_group, PlacementGroupSchedulingStrategy, PlacementGroup
from skyrl_train.utils.ppo_utils import AdvantageEstimatorRegistry, PolicyLossRegistry, sync_registries

@@ -195,7 +195,18 @@ def validate_cfg(cfg: DictConfig):
assert cfg.trainer.algorithm.loss_reduction in (
"token_mean",
"sequence_mean",
), f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. Must be one of `['token_mean', 'sequence_mean']`"
"seq_mean_token_sum_norm",
), f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`"

# add field to algorithm config needed for loss functions
# create a new config to make it modifiable
algorithm_config = OmegaConf.create(cfg.trainer.algorithm)
# NOTE (erictang000): this is the max sequence length including the prompt, since max response length
# per batch can be variable based on the prompt length. This is used to normalize the loss for
# seq_mean_token_sum_norm loss reduction. Potentially revisit this if we update to use a
# fixed max response budget.
algorithm_config.max_seq_len = cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length
cfg.trainer.algorithm = algorithm_config

if cfg.trainer.strategy == "deepspeed" and not (
cfg.trainer.policy.optimizer_config.offload_after_step
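The NOTE above fixes the normalizer for `seq_mean_token_sum_norm` to the configured prompt-plus-generation budget. A minimal sketch of that arithmetic, using assumed values rather than the repository's defaults:

```python
from omegaconf import OmegaConf

# Assumed generator settings; real runs take these from the training config.
cfg = OmegaConf.create(
    {
        "generator": {
            "max_input_length": 512,
            "sampling_params": {"max_generate_length": 1024},
        }
    }
)

max_seq_len = cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length
print(max_seq_len)  # 1536 -- the constant each sequence's token-loss sum is divided by
```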
11 changes: 0 additions & 11 deletions skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py
@@ -4,24 +4,20 @@
import ray
import torch
import torch.distributed
from functools import partial
from loguru import logger
from transformers import AutoModel

from transformers.trainer import get_scheduler


from skyrl_train.models import get_llm_for_sequence_regression, Actor
from skyrl_train.distributed.deepspeed_strategy import DeepspeedStrategy
from skyrl_train.utils import get_physical_gpu_id
from skyrl_train.utils.utils import str_to_torch_dtype
from skyrl_train.utils.ppo_utils import PolicyLossRegistry
from skyrl_train.workers.worker import (
PolicyWorkerBase,
CriticWorkerBase,
RewardWorkerBase,
RefWorkerBase,
ValueLoss,
)


@@ -94,10 +90,6 @@ def init_model(self, model_id_or_path):
(actor, actor_optim, actor_scheduler),
)

# set ppo loss function
policy_loss_func = PolicyLossRegistry.get(self.cfg.trainer.algorithm.policy_loss_type)
self.actor_loss_fn = partial(policy_loss_func, config=self.cfg.trainer.algorithm)

self.use_cuda_ipc = False
if self.cfg.generator.weight_sync_backend == "nccl" and self.cfg.trainer.placement.colocate_all:
self.use_cuda_ipc = True
@@ -287,9 +279,6 @@ def init_model(self, model_id_or_path):
(critic, critic_optim, critic_scheduler),
)

# set ppo loss function
self.critic_loss_fn = ValueLoss(self.cfg.trainer.algorithm.value_clip)


class DeepSpeedRewardWorkerBase(RewardWorkerBase):
def offload_to_cpu(self, pin_memory=True, non_blocking=True):
10 changes: 0 additions & 10 deletions skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
@@ -4,7 +4,6 @@
import ray
import torch
import torch.distributed
from functools import partial
from transformers import AutoModel, AutoConfig
from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
@@ -18,15 +17,13 @@
from skyrl_train.models import Actor, get_llm_for_sequence_regression
from skyrl_train.distributed.fsdp_strategy import FSDPStrategy
from skyrl_train.utils import get_physical_gpu_id, str_to_torch_dtype
from skyrl_train.utils.ppo_utils import PolicyLossRegistry
from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch
from skyrl_train.distributed.fsdp_utils import fsdp_version, get_init_weight_context_manager
from skyrl_train.workers.worker import (
PolicyWorkerBase,
CriticWorkerBase,
RewardWorkerBase,
RefWorkerBase,
ValueLoss,
)


@@ -88,10 +85,6 @@ def init_model(self, model_path):
self.optimizer is not None and self.scheduler is not None
), "FSDP preparation should create optimizer and scheduler"

# set ppo loss function
policy_loss_func = PolicyLossRegistry.get(self.cfg.trainer.algorithm.policy_loss_type)
self.actor_loss_fn = partial(policy_loss_func, config=self.cfg.trainer.algorithm)

self.use_cuda_ipc = False
if self.cfg.generator.weight_sync_backend == "nccl" and self.cfg.trainer.placement.colocate_all:
self.use_cuda_ipc = True
@@ -267,9 +260,6 @@ def init_model(self, model_path):
)
assert self.optimizer is not None

# set ppo loss function
self.critic_loss_fn = ValueLoss(self.cfg.trainer.algorithm.value_clip)

def forward(
self,
data: TrainingInputBatch,