2 changes: 1 addition & 1 deletion skyrl-train/docs/configuration/config.rst
@@ -307,7 +307,7 @@ Algorithm Configuration
value_clip: 0.2
normalize_reward: true

- ``algorithm.advantage_estimator``: Advantage estimator to use. Currently, we support ``grpo`` and ``gae``.
- ``algorithm.advantage_estimator``: Advantage estimator to use. We currently implement ``grpo`` and ``gae``, and custom advantage estimators can be registered with the ``AdvantageEstimatorRegistry``.
- ``algorithm.use_kl_estimator_k3``: Whether to use the k3 estimator for KL divergence calculation. The k3 estimator is the non-negative KL approximation in `this blog post <http://joschu.net/blog/kl-approx.html>`_. Besides being non-negative, it is also unbiased and has lower variance.
- ``algorithm.use_abs_kl``: Whether to use the absolute KL divergence for KL divergence calculation.
- ``algorithm.use_kl_in_reward``: Whether to apply KL divergence penalty to rewards. The new rewards will be computed as ``rewards - kl * kl_loss_coef``.
@@ -0,0 +1,56 @@
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you think we should add this example to a docs page? Maybe it's time for us to initialize a docs section for "Algorithms", and this small example could be the "Custom algorithms" page under that section.

We can add explainers of the relevant configs to tune for each algorithm in each of the doc pages (i.e. for PPO, GRPO, DAPO, GSPO, etc..) - wdyt?

At the very least should update the https://github.com/NovaSky-AI/SkyRL/blob/main/skyrl-train/docs/configuration/config.rst config page to mention the AdvEstimator registry

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes good call. Let me take this as a todo? I am working on one more DAPO feature and on GSPO. I will wrap these up in the next couple days and it will be a good time to introduce an "Algorithms" section.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created issue at #119

uv run --isolated --extra vllm -m examples.algorithm.custom_advantage_estimator.main_custom_adv_est
"""

import ray
import hydra
import torch
import numpy as np
from omegaconf import DictConfig
from skyrl_train.utils import initialize_ray
from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg
from skyrl_train.utils.ppo_utils import AdvantageEstimatorRegistry


# Example of custom advantage estimator: "simple_baseline"
def compute_simple_baseline_advantage(
token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, **kwargs
):
"""
A simple custom advantage estimator that uses response-level rewards
and computes advantages against a simple baseline.

This is just an example - replace with your own logic.
"""
with torch.no_grad():
response_rewards = (token_level_rewards * response_mask).sum(dim=-1, keepdim=True)

# Simple baseline: use the mean reward across the batch
baseline = response_rewards.mean()
advantages = (response_rewards - baseline) * response_mask
returns = advantages.clone()

return advantages, returns


@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):
# Register the custom advantage estimator
AdvantageEstimatorRegistry.register("simple_baseline", compute_simple_baseline_advantage)

# make sure that the training loop is not run on the head node.
exp = BasePPOExp(cfg)
exp.run()


@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
def main(cfg: DictConfig) -> None:
# validate the arguments
validate_cfg(cfg)

initialize_ray(cfg)
ray.get(skyrl_entrypoint.remote(cfg))


if __name__ == "__main__":
main()
@@ -0,0 +1,57 @@
set -x

# Example of custom advantage estimator: "simple_baseline"
# Colocated training+generation for Qwen2.5-0.5B-Instruct on GSM8K with the custom "simple_baseline" advantage estimator.

# uv run examples/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
# export WANDB_API_KEY=<your_key_here>
# then run this script

# NOTE (sumanthrh): `micro_train_batch_size_per_gpu` and `micro_forward_batch_size_per_gpu` can be tuned

DATA_DIR="$HOME/data/gsm8k"
NUM_GPUS=4
LOGGER="wandb" # change to "console" to print to stdout

# Configure the advantage estimator to use
ADV_EST="simple_baseline"

uv run --isolated --extra vllm -m examples.algorithm.custom_advantage_estimator.main_custom_adv_est \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.algorithm.advantage_estimator="$ADV_EST" \
trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \
trainer.placement.colocate_all=true \
trainer.strategy=fsdp2 \
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
generator.num_inference_engines=$NUM_GPUS \
generator.inference_engine_tensor_parallel_size=1 \
trainer.epochs=20 \
trainer.eval_batch_size=1024 \
trainer.eval_before_train=true \
trainer.eval_interval=5 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=1024 \
trainer.policy_mini_batch_size=256 \
trainer.micro_forward_batch_size_per_gpu=64 \
trainer.micro_train_batch_size_per_gpu=64 \
trainer.ckpt_interval=10 \
trainer.max_prompt_length=512 \
generator.sampling_params.max_generate_length=1024 \
trainer.policy.optimizer_config.lr=1.0e-6 \
trainer.algorithm.use_kl_loss=true \
generator.backend=vllm \
generator.run_engines_locally=true \
generator.weight_sync_backend=nccl \
generator.async_engine=true \
generator.batched=true \
environment.env_class=gsm8k \
generator.n_samples_per_prompt=5 \
generator.gpu_memory_utilization=0.8 \
trainer.logger="$LOGGER" \
trainer.project_name="custom_adv_est_gsm8k" \
trainer.run_name="custom_adv_est_gsm8k_test" \
trainer.resume_mode=null \
trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \
$@
2 changes: 1 addition & 1 deletion skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -76,7 +76,7 @@ trainer:
fsdp_size: -1
sequence_parallel_size: 1
algorithm:
advantage_estimator: "grpo"
advantage_estimator: "grpo" # Customizable with AdvantageEstimatorRegistry
kl_target: null
init_kl_coef: 0.0
use_kl_estimator_k3: true
112 changes: 93 additions & 19 deletions skyrl-train/skyrl_train/utils/ppo_utils.py
@@ -18,12 +18,84 @@

import torch
import numpy as np
from typing import Optional, Tuple
from typing import Optional, Tuple, Union, List, Callable, Dict
from enum import Enum
from skyrl_train.training_batch import TrainingInputBatch
from jaxtyping import Float
from collections import defaultdict


class AdvantageEstimator(Enum):
GAE = "gae"
GRPO = "grpo"

def __str__(self):
return self.value


class AdvantageEstimatorRegistry:
"""
Registry for advantage estimator functions.

This registry allows users to register custom advantage estimators without modifying
the skyrl_train package. Custom estimators can be registered by calling
AdvantageEstimatorRegistry.register() directly or by using the @register_advantage_estimator
decorator.

See examples/algorithm/custom_advantage_estimator for a simple example of how to
register and use custom advantage estimators.
"""

_estimators: Dict[str, Callable] = {}

@classmethod
def register(cls, name: Union[str, AdvantageEstimator], func: Callable):
"""Register an advantage estimator function."""
# Convert enum to string if needed
if isinstance(name, AdvantageEstimator):
name = name.value

if name in cls._estimators:
raise ValueError(f"Estimator '{name}' already registered")

cls._estimators[name] = func

@classmethod
def get(cls, name: str) -> Callable:
"""Get an estimator function by name."""
if name not in cls._estimators:
available = list(cls._estimators.keys())
raise ValueError(f"Unknown estimator '{name}'. Available: {available}")
return cls._estimators[name]

@classmethod
def list_available(cls) -> List[str]:
"""List all registered estimators."""
return list(cls._estimators.keys())

@classmethod
def unregister(cls, name: Union[str, AdvantageEstimator]):
"""Unregister an advantage estimator function. Useful for testing."""
# Convert enum to string if needed
if isinstance(name, AdvantageEstimator):
name = name.value

if name not in cls._estimators:
raise ValueError(f"Estimator '{name}' not registered")

del cls._estimators[name]


def register_advantage_estimator(name: Union[str, AdvantageEstimator]):
"""Decorator to register an advantage estimator function."""

def decorator(func: Callable):
AdvantageEstimatorRegistry.register(name, func)
return func

return decorator


# TODO (erictang000): unused right now, but will be useful as we add more algorithm support
class AdaptiveKLController:
"""
Expand Down Expand Up @@ -144,12 +216,14 @@ def masked_whiten(values, mask, shift_mean=True):
return whitened


@register_advantage_estimator(AdvantageEstimator.GAE)
def compute_gae_advantage_return(
token_level_rewards: Float[torch.Tensor, "batch_size seqlen"],
values: Float[torch.Tensor, "batch_size seqlen"],
response_mask: Float[torch.Tensor, "batch_size seqlen"],
gamma: float,
lambd: float,
**kwargs,
) -> Tuple[Float[torch.Tensor, "batch_size seqlen"], Float[torch.Tensor, "batch_size seqlen"]]:
"""
Compute advantage and return for GAE.
@@ -173,12 +247,14 @@ def compute_gae_advantage_return(
return advantages, returns


@register_advantage_estimator(AdvantageEstimator.GRPO)
def compute_grpo_outcome_advantage(
token_level_rewards: torch.Tensor,
response_mask: torch.Tensor,
index: np.ndarray,
epsilon: float = 1e-6,
norm_adv_by_std_in_grpo: bool = True,
**kwargs,
):
"""
Compute advantage for GRPO, operating only on Outcome reward (with only one scalar reward for each response).
@@ -228,27 +304,25 @@ def compute_advantages_and_returns(
     token_level_rewards: torch.Tensor,
     response_mask: torch.Tensor,
     index: np.ndarray,
-    adv_estimator: str,
+    adv_estimator: Union[str, AdvantageEstimator],
     values: Optional[torch.Tensor] = None,
     norm_adv_by_std_in_grpo: bool = True,
     gamma=1.0,
     lambd=1.0,
 ):
-    if adv_estimator == "gae":
-        advantages, returns = compute_gae_advantage_return(
-            token_level_rewards=token_level_rewards,
-            values=values,
-            response_mask=response_mask,
-            gamma=gamma,
-            lambd=lambd,
-        )
-    elif adv_estimator == "grpo":
-        advantages, returns = compute_grpo_outcome_advantage(
-            token_level_rewards=token_level_rewards,
-            response_mask=response_mask,
-            index=index,
-            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
-        )
+    if isinstance(adv_estimator, AdvantageEstimator):
+        estimator_name = adv_estimator.value
     else:
-        raise ValueError(f"Invalid adv_estimator: {adv_estimator}")
-    return advantages, returns
+        estimator_name = adv_estimator
+
+    estimator_func = AdvantageEstimatorRegistry.get(estimator_name)
+
+    return estimator_func(
+        token_level_rewards=token_level_rewards,
+        response_mask=response_mask,
+        index=index,
+        values=values,
+        norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
+        gamma=gamma,
+        lambd=lambd,
+    )
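As a usage note (not part of the diff): the `register_advantage_estimator` decorator added above can also register estimators defined in user code. Below is a minimal sketch; the name `mean_baseline` and the advantage logic are hypothetical, and, per the review discussion further down, the module doing the registration must be imported in the Ray worker process that runs training.

import numpy as np
import torch

from skyrl_train.utils.ppo_utils import register_advantage_estimator


# Sketch: register a custom estimator by name, then select it at launch time with
# trainer.algorithm.advantage_estimator="mean_baseline". The name and the logic are
# illustrative only.
@register_advantage_estimator("mean_baseline")
def compute_mean_baseline_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, **kwargs
):
    # Response-level reward minus the batch-mean baseline, broadcast over response tokens.
    with torch.no_grad():
        response_rewards = (token_level_rewards * response_mask).sum(dim=-1, keepdim=True)
        advantages = (response_rewards - response_rewards.mean()) * response_mask
    return advantages, advantages.clone()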
4 changes: 0 additions & 4 deletions skyrl-train/skyrl_train/utils/utils.py
@@ -140,10 +140,6 @@ def validate_cfg(cfg: DictConfig):
cfg.trainer.sequence_parallel_backend == "ulysses"
), f"only ulysses is supported as of now, got {cfg.trainer.sequence_parallel_backend}"

assert cfg.trainer.algorithm.advantage_estimator in (
"gae",
"grpo",
), f"invalid advantage estimator: {cfg.trainer.algorithm.advantage_estimator}"

Member Author:
Note: this is an unfortunate deletion. The tricky part is that the advantage estimator registration has to happen on a Ray worker (or be passed into the Ray workers), but our config validation currently happens before calling into the head Ray process, so we can't actually check the configured advantage_estimator against the contents of the registry.

This will be a problem with any registries we add later, so I opened an issue on it: #116. One option is to move config validation to inside the Ray head worker.

Collaborator:
Hmm, yeah. I guess we don't run any validation for env names for the same reason... this makes sense for now.

Member Author:
Yeah. Updated the issue (#116) with the env name validation point, too.
# if advantage estimator is GAE, then critic path should be provided
if cfg.trainer.algorithm.advantage_estimator == "gae":
assert (
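To illustrate the option discussed in the thread above (checking the configured estimator only where registration has already happened), here is a minimal sketch, not part of this PR, that re-adds the validation inside the Ray entrypoint from the example file; the exact placement is an assumption.

import ray
from omegaconf import DictConfig

from skyrl_train.entrypoints.main_base import BasePPOExp
from skyrl_train.utils.ppo_utils import AdvantageEstimatorRegistry


@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):
    # ... register any custom estimators here, as in the example entrypoint above ...

    # Late validation: the registry in this process now contains both built-in and
    # custom estimators, so the configured name can be checked against it.
    configured = cfg.trainer.algorithm.advantage_estimator
    available = AdvantageEstimatorRegistry.list_available()
    assert configured in available, f"invalid advantage estimator: {configured}. Available: {available}"

    exp = BasePPOExp(cfg)
    exp.run()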