diff --git a/skyrl-train/docs/configuration/config.rst b/skyrl-train/docs/configuration/config.rst
index 3eeebf548f..878f020e78 100644
--- a/skyrl-train/docs/configuration/config.rst
+++ b/skyrl-train/docs/configuration/config.rst
@@ -291,7 +291,8 @@ Algorithm Configuration
     advantage_batch_normalize: false
     value_head_prefix: "value_head"
     policy_loss_type: "regular" # "regular", "dual_clip", "gspo", or customizable with PolicyLossRegistry
-    loss_reduction: "token_mean" # "token_mean", "sequence_mean"
+    loss_reduction: "token_mean" # "token_mean", "sequence_mean", "seq_mean_token_sum_norm"
+    grpo_norm_by_std: true # set to false to disable normalization by std in GRPO (used in Dr. GRPO)

     # GAE parameters
     lambd: 1.0
@@ -329,7 +330,13 @@ Algorithm Configuration
   - ``gspo``: `Group Sequence Policy Optimization `_ with sequence-level importance sampling for improved training stability. Implements "GSPO-token" variant from the paper.
   - Custom policy losses can be registered with the ``PolicyLossRegistry``

-- ``algorithm.loss_reduction``: Type of loss reduction to use. Options are ``token_mean`` and ``sequence_mean``. ``token_mean`` matches token-level loss introduced by `DAPO `_. ``sequence_mean`` computes per-sequence avg token loss, then averages over the batch.
+- ``algorithm.loss_reduction``: Type of loss reduction to use. Options include:
+
+  - ``token_mean``: computes average loss over all valid tokens in the batch. Used in `DAPO `_.
+  - ``sequence_mean``: computes per-sequence avg token loss, then averages over the batch.
+  - ``seq_mean_token_sum_norm``: computes the sum of token losses for each sequence, normalizes by the max sequence length (computed as ``cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length``), and then averages over the batch. This is used in `Dr. GRPO `_.
+
+- ``algorithm.grpo_norm_by_std``: Whether to normalize advantages by the standard deviation in GRPO. This is set to ``false`` in `Dr. GRPO `_.
 - ``algorithm.lambd``: Lambda parameter for GAE.
 - ``algorithm.gamma``: Gamma parameter for GAE.
 - ``algorithm.eps_clip_low``: Lower bound for PPO clipping.
diff --git a/skyrl-train/examples/algorithm/drgrpo/run_drgrpo_gsm8k.sh b/skyrl-train/examples/algorithm/drgrpo/run_drgrpo_gsm8k.sh
new file mode 100644
index 0000000000..faf0e60540
--- /dev/null
+++ b/skyrl-train/examples/algorithm/drgrpo/run_drgrpo_gsm8k.sh
@@ -0,0 +1,61 @@
+set -x
+
+# Colocated Dr. GRPO training+generation for Qwen2.5-1.5B-Instruct on GSM8K.
+
+# uv run examples/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
+# export WANDB_API_KEY=
+# bash examples/algorithm/drgrpo/run_drgrpo_gsm8k.sh
+
+# TODO (erictang000): add a description of the algorithm once GRPO docs are added.
+
+DATA_DIR="$HOME/data/gsm8k"
+NUM_GPUS=4
+LOGGER="wandb" # change to "console" to print to stdout
+
+# Dr. GRPO parameters
+
+LOSS_REDUCTION="seq_mean_token_sum_norm"
+GRPO_NORM_BY_STD=false
+USE_KL_LOSS=false
+
+uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \
+  data.train_data="['$DATA_DIR/train.parquet']" \
+  data.val_data="['$DATA_DIR/validation.parquet']" \
+  trainer.algorithm.advantage_estimator="grpo" \
+  trainer.algorithm.loss_reduction=$LOSS_REDUCTION \
+  trainer.algorithm.grpo_norm_by_std=$GRPO_NORM_BY_STD \
+  trainer.algorithm.use_kl_loss=$USE_KL_LOSS \
+  trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \
+  trainer.placement.colocate_all=true \
+  trainer.strategy=fsdp2 \
+  trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
+  trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
+  generator.num_inference_engines=$NUM_GPUS \
+  generator.inference_engine_tensor_parallel_size=1 \
+  trainer.epochs=20 \
+  trainer.eval_batch_size=1024 \
+  trainer.eval_before_train=true \
+  trainer.eval_interval=5 \
+  trainer.update_epochs_per_batch=1 \
+  trainer.train_batch_size=1024 \
+  trainer.policy_mini_batch_size=256 \
+  trainer.micro_forward_batch_size_per_gpu=64 \
+  trainer.micro_train_batch_size_per_gpu=64 \
+  trainer.ckpt_interval=10 \
+  trainer.max_prompt_length=512 \
+  generator.sampling_params.max_generate_length=1024 \
+  trainer.policy.optimizer_config.lr=1.0e-6 \
+  generator.backend=vllm \
+  generator.run_engines_locally=true \
+  generator.weight_sync_backend=nccl \
+  generator.async_engine=true \
+  generator.batched=true \
+  environment.env_class=gsm8k \
+  generator.n_samples_per_prompt=5 \
+  generator.gpu_memory_utilization=0.8 \
+  trainer.logger="$LOGGER" \
+  trainer.project_name="gsm8k" \
+  trainer.run_name="gsm8k_drgrpo" \
+  trainer.resume_mode=null \
+  trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \
+  $@
diff --git a/skyrl-train/examples/ppo/run_ppo.sh b/skyrl-train/examples/ppo/run_ppo.sh
index 99793f8577..353d28bf31 100644
--- a/skyrl-train/examples/ppo/run_ppo.sh
+++ b/skyrl-train/examples/ppo/run_ppo.sh
@@ -43,7 +43,7 @@ uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \
   generator.gpu_memory_utilization=0.8 \
   trainer.logger="wandb" \
   trainer.project_name="gsm8k" \
-  trainer.run_name="gsm8k_test" \
+  trainer.run_name="gsm8k_ppo" \
   trainer.resume_mode=null \
   trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt_ppo" \
   trainer.eval_batch_size=1024 \
diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
index eea919a9df..630f4d03bc 100644
--- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml
+++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -89,7 +89,8 @@ trainer:
     advantage_batch_normalize: false
     value_head_prefix: "value_head"
     policy_loss_type: "regular" # "regular", "dual_clip", "gspo", or customizable with PolicyLossRegistry
-    loss_reduction: "token_mean" # "token_mean", "sequence_mean"
+    loss_reduction: "token_mean" # "token_mean", "sequence_mean", "seq_mean_token_sum_norm"
+    grpo_norm_by_std: true # set to false to disable normalization by std in GRPO
     # GAE parameters
     lambd: 1.0
     gamma: 1.0
diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py
index 7b558755e2..65808e1007 100644
--- a/skyrl-train/skyrl_train/trainer.py
+++ b/skyrl-train/skyrl_train/trainer.py
@@ -756,6 +756,7 @@ def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingIn
             values=data["values"],
             gamma=self.cfg.trainer.algorithm.gamma,
             lambd=self.cfg.trainer.algorithm.lambd,
+            grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
         )
         data["returns"] = returns
         data["advantages"] = advantages
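Illustrative sketch (not part of the patch): what the new ``grpo_norm_by_std`` flag toggles inside ``compute_grpo_outcome_advantage`` in the next file. With the flag on (standard GRPO), group-relative scores are divided by the group standard deviation; with it off (the Dr. GRPO setting), only the group mean is subtracted. The reward values below are made up for illustration.

    import torch

    rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])  # outcome rewards for 4 samples of one prompt group
    epsilon = 1e-6
    mean, std = rewards.mean(), rewards.std()

    adv_grpo = (rewards - mean) / (std + epsilon)  # grpo_norm_by_std=true (standard GRPO)
    adv_drgrpo = rewards - mean                    # grpo_norm_by_std=false (Dr. GRPO)

    print(adv_grpo)    # ~[ 0.87, -0.87, -0.87,  0.87]
    print(adv_drgrpo)  # [ 0.5, -0.5, -0.5,  0.5]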
data["returns"] = returns data["advantages"] = advantages diff --git a/skyrl-train/skyrl_train/utils/ppo_utils.py b/skyrl-train/skyrl_train/utils/ppo_utils.py index 3591f7cad6..d61943c53e 100644 --- a/skyrl-train/skyrl_train/utils/ppo_utils.py +++ b/skyrl-train/skyrl_train/utils/ppo_utils.py @@ -20,7 +20,6 @@ from enum import StrEnum from typing import Callable, List, Tuple, Union, Optional, Literal from functools import wraps - import torch import numpy as np @@ -160,6 +159,28 @@ def masked_whiten(values, mask, shift_mean=True): return whitened +def ppo_critic_loss( + values: torch.Tensor, + old_values: torch.Tensor, + returns: torch.Tensor, + config: DictConfig, + loss_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[float]]: + + if config.value_clip is not None: + values_clipped = old_values + (values - old_values).clamp(-config.value_clip, config.value_clip) + surr1 = (values_clipped - returns) ** 2 + surr2 = (values - returns) ** 2 + loss = torch.max(surr1, surr2) + clipfrac = masked_mean((surr1 > surr2).float(), loss_mask).mean().detach().item() + else: + clipfrac = None + loss = (values - returns) ** 2 + + loss = masked_mean(loss, loss_mask, dim=-1).mean() + return 0.5 * loss, clipfrac + + # Shared registry actor class for both policy loss and advantage estimator registries @ray.remote class RegistryActor: @@ -468,7 +489,8 @@ def ppo_policy_loss( assert loss_reduction in [ "token_mean", "sequence_mean", - ], "loss_reduction must be either 'token_mean' or 'sequence_mean'" + "seq_mean_token_sum_norm", + ], "loss_reduction must be either 'token_mean', 'sequence_mean', or 'seq_mean_token_sum_norm'" ratio = (log_probs - old_log_probs).exp() surr1 = ratio * advantages @@ -480,7 +502,7 @@ def ppo_policy_loss( pg_losses3 = -advantages * config.clip_ratio_c clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1) loss = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1) - loss = reduce_loss(loss, loss_mask, loss_reduction) + loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) return loss, clip_ratio @@ -538,13 +560,16 @@ def gspo_policy_loss( # Compute clipping ratio for monitoring clip_ratio = masked_mean((-surr2 > -surr1).float(), loss_mask).mean().detach().item() - loss = reduce_loss(loss, loss_mask, loss_reduction) + loss = reduce_loss(loss, loss_mask, loss_reduction, config.max_seq_len) return loss, clip_ratio def reduce_loss( - loss: torch.Tensor, loss_mask: Optional[torch.Tensor], loss_reduction: Literal["token_mean", "sequence_mean"] + loss: torch.Tensor, + loss_mask: Optional[torch.Tensor], + loss_reduction: Literal["token_mean", "sequence_mean", "seq_mean_token_sum_norm"], + max_seq_len: Optional[int] = None, ) -> torch.Tensor: if loss_reduction == "token_mean": # sum over *all* valid tokens, divide by total valid-token count @@ -552,6 +577,17 @@ def reduce_loss( elif loss_reduction == "sequence_mean": # per-sequence token-mean (dim=-1), then batch-mean loss = masked_mean(loss, loss_mask, dim=-1).mean() + elif loss_reduction == "seq_mean_token_sum_norm": + # per-sequence token-sum, normalized by the max sequence length, then batch mean + # this is the Dr. 
GRPO loss reduction to avoid length bias by normalizing by a constant + assert max_seq_len is not None, "max_seq_len must be provided for seq_mean_token_sum_norm loss reduction" + # NOTE: max_seq_len is computed as cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length by default + if loss_mask is not None: + seq_losses = torch.sum(loss * loss_mask, dim=-1) / max_seq_len + else: + # If no mask, assume all tokens are valid + seq_losses = torch.sum(loss, dim=-1) / max_seq_len + loss = torch.mean(seq_losses) else: raise ValueError(f"Invalid loss reduction type: {loss_reduction}") return loss @@ -594,7 +630,7 @@ def compute_grpo_outcome_advantage( response_mask: torch.Tensor, index: np.ndarray, epsilon: float = 1e-6, - norm_adv_by_std_in_grpo: bool = True, + grpo_norm_by_std: bool = True, **kwargs, ): """ @@ -605,7 +641,7 @@ def compute_grpo_outcome_advantage( - response_mask: Float[torch.Tensor, "batch_size seqlen"] - index: np.ndarray (batch_size) - epsilon: float - - norm_adv_by_std_in_grpo: bool + - grpo_norm_by_std: bool Returns: - advantages: Float[torch.Tensor, "batch_size seqlen"] @@ -632,7 +668,7 @@ def compute_grpo_outcome_advantage( else: raise ValueError(f"no score in prompt index: {idx}") for i in range(bsz): - if norm_adv_by_std_in_grpo: + if grpo_norm_by_std: scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) else: scores[i] = scores[i] - id2mean[index[i]] @@ -647,7 +683,7 @@ def compute_advantages_and_returns( index: np.ndarray, adv_estimator: AdvantageEstimator, values: Optional[torch.Tensor] = None, - norm_adv_by_std_in_grpo: bool = True, + grpo_norm_by_std: bool = True, gamma=1.0, lambd=1.0, ): @@ -658,7 +694,7 @@ def compute_advantages_and_returns( response_mask=response_mask, index=index, values=values, - norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + grpo_norm_by_std=grpo_norm_by_std, gamma=gamma, lambd=lambd, ) diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py index 956a6887a5..f7b668a255 100644 --- a/skyrl-train/skyrl_train/utils/utils.py +++ b/skyrl-train/skyrl_train/utils/utils.py @@ -4,7 +4,7 @@ import ray import torch from loguru import logger -from omegaconf.dictconfig import DictConfig +from omegaconf import DictConfig, OmegaConf from ray.util.placement_group import placement_group, PlacementGroupSchedulingStrategy, PlacementGroup from skyrl_train.utils.ppo_utils import AdvantageEstimatorRegistry, PolicyLossRegistry, sync_registries @@ -195,7 +195,18 @@ def validate_cfg(cfg: DictConfig): assert cfg.trainer.algorithm.loss_reduction in ( "token_mean", "sequence_mean", - ), f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. Must be one of `['token_mean', 'sequence_mean']`" + "seq_mean_token_sum_norm", + ), f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`" + + # add field to algorithm config needed for loss functions + # create a new config to make it modifiable + algorithm_config = OmegaConf.create(cfg.trainer.algorithm) + # NOTE (erictang000): this is the max sequence length including the prompt, since max response length + # per batch can be variable based on the prompt length. This is used to normalize the loss for + # seq_mean_token_sum_norm loss reduction. Potentially revisit this if we update to use a + # fixed max response budget. 
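Illustrative sketch (not part of the patch), using made-up losses and a toy budget of ``max_seq_len = 4``: how ``seq_mean_token_sum_norm`` differs from ``sequence_mean``. Each sequence's token-loss sum is divided by the same constant, so a token in a short response carries the same weight as a token in a long one, which is the length-bias fix from Dr. GRPO. The ``max_seq_len`` plumbing added to ``validate_cfg`` in the next file supplies that constant.

    import torch

    loss = torch.tensor([[0.5, 0.5, 0.0, 0.0],   # short response: 2 valid tokens
                         [0.5, 0.5, 0.5, 0.5]])  # long response: 4 valid tokens
    mask = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                         [1.0, 1.0, 1.0, 1.0]])
    max_seq_len = 4  # constant budget, e.g. max_input_length + max_generate_length

    sequence_mean = ((loss * mask).sum(-1) / mask.sum(-1)).mean()           # 0.5: each sequence normalized by its own length
    seq_mean_token_sum_norm = ((loss * mask).sum(-1) / max_seq_len).mean()  # 0.375: every token gets a constant weight
    print(sequence_mean.item(), seq_mean_token_sum_norm.item())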
diff --git a/skyrl-train/skyrl_train/utils/utils.py b/skyrl-train/skyrl_train/utils/utils.py
index 956a6887a5..f7b668a255 100644
--- a/skyrl-train/skyrl_train/utils/utils.py
+++ b/skyrl-train/skyrl_train/utils/utils.py
@@ -4,7 +4,7 @@
 import ray
 import torch
 from loguru import logger
-from omegaconf.dictconfig import DictConfig
+from omegaconf import DictConfig, OmegaConf
 from ray.util.placement_group import placement_group, PlacementGroupSchedulingStrategy, PlacementGroup

 from skyrl_train.utils.ppo_utils import AdvantageEstimatorRegistry, PolicyLossRegistry, sync_registries
@@ -195,7 +195,18 @@ def validate_cfg(cfg: DictConfig):
     assert cfg.trainer.algorithm.loss_reduction in (
         "token_mean",
         "sequence_mean",
-    ), f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. Must be one of `['token_mean', 'sequence_mean']`"
+        "seq_mean_token_sum_norm",
+    ), f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`"
+
+    # add field to algorithm config needed for loss functions
+    # create a new config to make it modifiable
+    algorithm_config = OmegaConf.create(cfg.trainer.algorithm)
+    # NOTE (erictang000): this is the max sequence length including the prompt, since max response length
+    # per batch can be variable based on the prompt length. This is used to normalize the loss for
+    # seq_mean_token_sum_norm loss reduction. Potentially revisit this if we update to use a
+    # fixed max response budget.
+    algorithm_config.max_seq_len = cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length
+    cfg.trainer.algorithm = algorithm_config

     if cfg.trainer.strategy == "deepspeed" and not (
         cfg.trainer.policy.optimizer_config.offload_after_step
diff --git a/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py b/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py
index 8084a2cc5e..0f9af556eb 100644
--- a/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py
+++ b/skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py
@@ -4,10 +4,8 @@
 import ray
 import torch
 import torch.distributed
-from functools import partial
 from loguru import logger
 from transformers import AutoModel
-
 from transformers.trainer import get_scheduler


@@ -15,13 +13,11 @@
 from skyrl_train.distributed.deepspeed_strategy import DeepspeedStrategy
 from skyrl_train.utils import get_physical_gpu_id
 from skyrl_train.utils.utils import str_to_torch_dtype
-from skyrl_train.utils.ppo_utils import PolicyLossRegistry
 from skyrl_train.workers.worker import (
     PolicyWorkerBase,
     CriticWorkerBase,
     RewardWorkerBase,
     RefWorkerBase,
-    ValueLoss,
 )


@@ -94,10 +90,6 @@ def init_model(self, model_id_or_path):
             (actor, actor_optim, actor_scheduler),
         )

-        # set ppo loss function
-        policy_loss_func = PolicyLossRegistry.get(self.cfg.trainer.algorithm.policy_loss_type)
-        self.actor_loss_fn = partial(policy_loss_func, config=self.cfg.trainer.algorithm)
-
         self.use_cuda_ipc = False
         if self.cfg.generator.weight_sync_backend == "nccl" and self.cfg.trainer.placement.colocate_all:
             self.use_cuda_ipc = True
@@ -287,9 +279,6 @@ def init_model(self, model_id_or_path):
             (critic, critic_optim, critic_scheduler),
         )

-        # set ppo loss function
-        self.critic_loss_fn = ValueLoss(self.cfg.trainer.algorithm.value_clip)
-

 class DeepSpeedRewardWorkerBase(RewardWorkerBase):
     def offload_to_cpu(self, pin_memory=True, non_blocking=True):
diff --git a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
index a15b1eb7ae..4c2a632b12 100644
--- a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
+++ b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
@@ -4,7 +4,6 @@
 import ray
 import torch
 import torch.distributed
-from functools import partial
 from transformers import AutoModel, AutoConfig
 from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
@@ -18,7 +17,6 @@
 from skyrl_train.models import Actor, get_llm_for_sequence_regression
 from skyrl_train.distributed.fsdp_strategy import FSDPStrategy
 from skyrl_train.utils import get_physical_gpu_id, str_to_torch_dtype
-from skyrl_train.utils.ppo_utils import PolicyLossRegistry
 from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch
 from skyrl_train.distributed.fsdp_utils import fsdp_version, get_init_weight_context_manager
 from skyrl_train.workers.worker import (
@@ -26,7 +24,6 @@
     CriticWorkerBase,
     RewardWorkerBase,
     RefWorkerBase,
-    ValueLoss,
 )


@@ -88,10 +85,6 @@ def init_model(self, model_path):
             self.optimizer is not None and self.scheduler is not None
         ), "FSDP preparation should create optimizer and scheduler"

-        # set ppo loss function
-        policy_loss_func = PolicyLossRegistry.get(self.cfg.trainer.algorithm.policy_loss_type)
-        self.actor_loss_fn = partial(policy_loss_func, config=self.cfg.trainer.algorithm)
-
         self.use_cuda_ipc = False
         if self.cfg.generator.weight_sync_backend == "nccl" and self.cfg.trainer.placement.colocate_all:
             self.use_cuda_ipc = True
@@ -267,9 +260,6 @@ def init_model(self, model_path):
         )
         assert self.optimizer is not None

-        # set ppo loss function
-        self.critic_loss_fn = ValueLoss(self.cfg.trainer.algorithm.value_clip)
-
     def forward(
         self,
         data: TrainingInputBatch,
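Illustrative usage sketch (not part of the patch), assuming ``skyrl_train`` is importable: with ``ValueLoss`` removed from the backend workers above, the critic loss is the plain ``ppo_critic_loss`` function added in ``ppo_utils.py``, which takes the algorithm config at call time, mirroring how the policy loss is now resolved from ``PolicyLossRegistry`` in the shared worker below. The tensors are placeholders.

    import torch
    from omegaconf import OmegaConf
    from skyrl_train.utils.ppo_utils import ppo_critic_loss

    config = OmegaConf.create({"value_clip": 0.2})
    values = torch.tensor([[1.0, 1.2]])
    old_values = torch.tensor([[1.0, 1.0]])
    returns = torch.tensor([[1.1, 1.1]])
    loss_mask = torch.ones_like(values)

    loss, clipfrac = ppo_critic_loss(values, old_values, returns, config, loss_mask=loss_mask)
    print(loss.item(), clipfrac)  # scalar critic loss and the fraction of clipped value targets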
diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 6bf09eeabb..fe858581ed 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -2,7 +2,7 @@
 import logging
 import os
 import socket
-from typing import Dict, Optional, Type, List, Any
+from typing import Dict, Optional, Type, List, Any, Callable
 from ctypes import CDLL, POINTER, Structure, c_char_p, c_int, c_ulong, c_void_p
 from tqdm import tqdm
 from collections import defaultdict
@@ -24,6 +24,7 @@
 from skyrl_train.distributed.ulysses import set_ulysses_sequence_parallel_group, apply_monkey_patch
 from skyrl_train.distributed.utils import init_custom_process_group
 from skyrl_train.utils.torch_utils import chunked_entropy_from_logits
+from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss
 from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics
 from skyrl_train.dataset.replay_buffer import Experience
 from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch
@@ -315,37 +316,6 @@ def _forward_micro_batch(self, micro_batch: TrainingInputBatch) -> TrainingOutpu
         raise NotImplementedError()


-class ValueLoss(nn.Module):
-    """
-    Value Loss for PPO
-    """
-
-    def __init__(self, clip_eps: float = None) -> None:
-        super().__init__()
-        self.clip_eps = clip_eps
-
-    def forward(
-        self,
-        values: torch.Tensor,
-        old_values: torch.Tensor,
-        returns: torch.Tensor,
-        loss_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-
-        if self.clip_eps is not None:
-            values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
-            surr1 = (values_clipped - returns) ** 2
-            surr2 = (values - returns) ** 2
-            loss = torch.max(surr1, surr2)
-            clipfrac = masked_mean((surr1 > surr2).float(), loss_mask).mean().detach().item()
-        else:
-            clipfrac = None
-            loss = (values - returns) ** 2
-
-        loss = masked_mean(loss, loss_mask, dim=-1).mean()
-        return 0.5 * loss, clipfrac
-
-
 # adapted from OpenReasonerZero: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero/blob/main/orz/ppo/actors.py
 class PPORayActorGroup:
     """
@@ -604,7 +574,7 @@ def __init__(self, **kwargs):
         self.strategy: DistributedStrategy = None
         self.record_memory: bool = False
         self.mesh_rank: MeshRank = None
-        self.actor_loss_fn: nn.Module = None
+        self.policy_loss_fn: Callable = PolicyLossRegistry.get(self.cfg.trainer.algorithm.policy_loss_type)

     def _normalize_mini_batch_size(self):
         """
@@ -728,10 +698,11 @@ def training_step(self, experience: Experience, global_step, local_step, accumul
         )
         # loss function
         # TODO: recompute advantages
-        actor_loss, clip_ratio = self.actor_loss_fn(
+        policy_loss, clip_ratio = self.policy_loss_fn(
             action_log_probs,
             old_action_log_probs,
             advantages,
+            config=self.cfg.trainer.algorithm,
             loss_mask=loss_mask,
         )
         # entropy
@@ -757,7 +728,7 @@ def training_step(self, experience: Experience, global_step, local_step, accumul
         else:
             kl_loss = torch.tensor(0.0)

-        loss = actor_loss + kl_loss * self.cfg.trainer.algorithm.kl_loss_coef
+        loss = policy_loss + kl_loss * self.cfg.trainer.algorithm.kl_loss_coef
         loss = loss / accumulation_steps
         self.strategy.backward(loss, self.model, self.optimizer)

@@ -772,7 +743,7 @@ def training_step(self, experience: Experience, global_step, local_step, accumul

         # status
         status = {
-            "policy_loss": actor_loss.item(),
+            "policy_loss": policy_loss.item(),
             "policy_lr": self.scheduler.get_last_lr()[0],
             "ppo_clip_ratio": clip_ratio,
             "policy_entropy": entropy,
@@ -859,7 +830,7 @@ def __init__(self, **kwargs):
         self.strategy: DistributedStrategy = None
         self.record_memory: bool = False
         self.mesh_rank: MeshRank = None
-        self.critic_loss_fn: nn.Module = None
+        self.critic_loss_fn: Callable = ppo_critic_loss

     def _normalize_mini_batch_size(self):
         """
@@ -978,6 +949,7 @@ def training_step(self, experience: Experience, global_step, local_step, accumul
             values,
             old_values,
             returns,
+            config=self.cfg.trainer.algorithm,
             loss_mask=loss_mask,
         )
         loss = loss / accumulation_steps
diff --git a/skyrl-train/tests/cpu/algorithms/test_losses.py b/skyrl-train/tests/cpu/algorithms/test_losses.py
index ec1af38d17..d3e8979d80 100644
--- a/skyrl-train/tests/cpu/algorithms/test_losses.py
+++ b/skyrl-train/tests/cpu/algorithms/test_losses.py
@@ -32,6 +32,7 @@ def test_policy_loss_dual_clip():
             "clip_ratio_c": 3.0,
             "policy_loss_type": "dual_clip",
             "loss_reduction": "token_mean",
+            "max_seq_len": 4,
         }
     )

@@ -104,6 +105,7 @@ def test_policy_loss_reduction_modes():
             "clip_ratio_c": 3.0,
             "policy_loss_type": "regular",
             "loss_reduction": "token_mean",
+            "max_seq_len": 4,
         }
     )

@@ -114,6 +116,7 @@
             "clip_ratio_c": 3.0,
             "policy_loss_type": "regular",
             "loss_reduction": "sequence_mean",
+            "max_seq_len": 4,
         }
     )

@@ -185,6 +188,7 @@ def test_policy_loss_reduction_edge_cases():
             "clip_ratio_c": 3.0,
             "policy_loss_type": "regular",
             "loss_reduction": "token_mean",
+            "max_seq_len": 4,
         }
     )

@@ -195,6 +199,7 @@
             "clip_ratio_c": 3.0,
             "policy_loss_type": "regular",
             "loss_reduction": "sequence_mean",
+            "max_seq_len": 4,
         }
     )

@@ -279,6 +284,7 @@ def test_gspo_importance_sampling_levels():
             "clip_ratio_c": 3.0,
             "policy_loss_type": "regular",
             "loss_reduction": "token_mean",
+            "max_seq_len": 4,
         }
     )
     ppo_loss_fn = PolicyLossRegistry.get("regular")
@@ -292,6 +298,7 @@
             "clip_ratio_c": 3.0,
             "policy_loss_type": "gspo",
             "loss_reduction": "sequence_mean",  # GSPO recommended reduction
+            "max_seq_len": 4,
         }
     )
     gspo_loss_fn = PolicyLossRegistry.get("gspo")
diff --git a/skyrl-train/tests/cpu/test_trainer.py b/skyrl-train/tests/cpu/test_trainer.py
index 8a458c354f..0f41971e34 100644
--- a/skyrl-train/tests/cpu/test_trainer.py
+++ b/skyrl-train/tests/cpu/test_trainer.py
@@ -87,6 +87,7 @@ def dummy_config():
                 "normalize_reward": True,
                 "policy_loss_type": "regular",
                 "loss_reduction": "token_mean",
+                "grpo_norm_by_std": True,
             },
             "resume_mode": "none",
         },
@@ -254,6 +255,9 @@ def create_policy_worker_with_config(
             "train_batch_size": train_batch_size,
             "policy_mini_batch_size": policy_mini_batch_size,
             "micro_train_batch_size_per_gpu": micro_train_batch_size_per_gpu,
+            "algorithm": {
+                "policy_loss_type": "regular",
+            },
         },
         "generator": {
             "n_samples_per_prompt": n_samples_per_prompt,
@@ -565,6 +569,9 @@ def test_ppo_train_batch_calculations():
         "trainer": {
             "micro_train_batch_size_per_gpu": 2,
             "update_epochs_per_batch": 1,
+            "algorithm": {
+                "policy_loss_type": "regular",
+            },
         },
         "generator": {
             "sampling_params": {
diff --git a/skyrl-train/tests/cpu/utils/test_ppo_utils.py b/skyrl-train/tests/cpu/utils/test_ppo_utils.py
index 5ac86e11ec..16205e97f0 100644
--- a/skyrl-train/tests/cpu/utils/test_ppo_utils.py
+++ b/skyrl-train/tests/cpu/utils/test_ppo_utils.py
@@ -7,6 +7,7 @@
 import math
 import pytest
 from skyrl_train.utils.ppo_utils import (
+    reduce_loss,
     compute_approx_kl,
     compute_gae_advantage_return,
     compute_grpo_outcome_advantage,
@@ -65,6 +66,35 @@ def test_compute_grpo_outcome_advantage(advantage_test_data):
     assert torch.allclose(adv, ret), "Advantages and returns should be equal with GRPO"


+def test_compute_grpo_outcome_advantage_norm_std_false():
+    """Test GRPO advantage computation with grpo_norm_by_std=False."""
+    # Two groups: [6.0, 3.0] mean=4.5, [9.0, 12.0] mean=10.5
+    token_level_rewards = torch.tensor(
+        [
+            [1.0, 2.0, 3.0],  # sum = 6.0, group 0
+            [1.0, 1.0, 1.0],  # sum = 3.0, group 0
+            [3.0, 3.0, 3.0],  # sum = 9.0, group 1
+            [4.0, 4.0, 4.0],  # sum = 12.0, group 1
+        ]
+    )
+    response_mask = torch.ones_like(token_level_rewards)
+    index = np.array([0, 0, 1, 1])
+
+    adv, ret = compute_grpo_outcome_advantage(
+        token_level_rewards=token_level_rewards,
+        response_mask=response_mask,
+        index=index,
+        grpo_norm_by_std=False,
+    )
+
+    # Expected: [6.0-4.5, 3.0-4.5, 9.0-10.5, 12.0-10.5] = [1.5, -1.5, -1.5, 1.5]
+    expected = torch.tensor([1.5, -1.5, -1.5, 1.5]).unsqueeze(-1) * response_mask
+
+    assert adv.shape == token_level_rewards.shape
+    assert torch.allclose(adv, ret), "Advantages and returns should be equal with GRPO"
+    assert torch.allclose(adv, expected, atol=1e-5), f"Expected {expected}, got {adv}"
+
+
 def test_compute_gae_advantage_return(advantage_test_data):
     rewards, values, response_mask, index = advantage_test_data

@@ -135,6 +165,32 @@ def test_compute_gae_advantage_return_lam(advantage_test_data):
     assert torch.allclose(ret, expected_ret, atol=1e-5)


+def test_reduce_loss():
+    """Test the reduce_loss function with different reduction types."""
+    # Test data: 2x3 loss tensor with different valid token counts per sequence
+    loss = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+    loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]])  # seq0 has 3 tokens, seq1 has 1 token
+
+    # Test token_mean: sum all valid losses / count valid tokens
+    # Valid losses: [1.0, 2.0, 3.0, 4.0], mean = 10.0/4 = 2.5
+    result_token = reduce_loss(loss, loss_mask, "token_mean")
+    expected_token = torch.tensor(2.5)
+    assert torch.allclose(result_token, expected_token), f"Expected {expected_token}, got {result_token}"
+
+    # Test sequence_mean: mean of per-sequence means
+    # Seq 0: (1.0 + 2.0 + 3.0) / 3 = 2.0, Seq 1: 4.0 / 1 = 4.0, batch mean = (2.0 + 4.0) / 2 = 3.0
+    result_seq = reduce_loss(loss, loss_mask, "sequence_mean")
+    expected_seq = torch.tensor(3.0)
+    assert torch.allclose(result_seq, expected_seq), f"Expected {expected_seq}, got {result_seq}"
+
+    # Test seq_mean_token_sum_norm: sum per sequence / max_len, then batch mean
+    # Seq 0: (1.0 + 2.0 + 3.0) / 4 = 1.5, Seq 1: 4.0 / 4 = 1.0, batch mean = (1.5 + 1.0) / 2 = 1.25
+    max_seq_len = 4
+    result_max = reduce_loss(loss, loss_mask, "seq_mean_token_sum_norm", max_seq_len)
+    expected_max = torch.tensor(1.25)
+    assert torch.allclose(result_max, expected_max), f"Expected {expected_max}, got {result_max}"
+
+
 def test_adaptive_kl_controller_update():
     controller = AdaptiveKLController(init_kl_coef=0.2, target=0.1, horizon=100)
     controller.update(current=0.2, n_steps=10)
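A final sanity-check sketch (not part of the patch), assuming ``skyrl_train`` is importable: when every sequence fills the full ``max_seq_len`` budget, ``seq_mean_token_sum_norm`` and ``sequence_mean`` agree; they only diverge for shorter responses, which is the case the new ``test_reduce_loss`` above exercises.

    import torch
    from skyrl_train.utils.ppo_utils import reduce_loss

    loss = torch.rand(2, 8)
    full_mask = torch.ones_like(loss)

    a = reduce_loss(loss, full_mask, "sequence_mean")
    b = reduce_loss(loss, full_mask, "seq_mean_token_sum_norm", max_seq_len=8)
    assert torch.allclose(a, b)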