[Trainer] Support registering custom advantage estimators #115
New file: `examples/algorithm/custom_advantage_estimator/main_custom_adv_est.py`

```python
"""
uv run --isolated --extra vllm -m examples.algorithm.custom_advantage_estimator.main_custom_adv_est
"""

import ray
import hydra
import torch
import numpy as np
from omegaconf import DictConfig
from skyrl_train.utils import initialize_ray
from skyrl_train.entrypoints.main_base import BasePPOExp, config_dir, validate_cfg
from skyrl_train.utils.ppo_utils import AdvantageEstimatorRegistry


# Example of custom advantage estimator: "simple_baseline"
def compute_simple_baseline_advantage(
    token_level_rewards: torch.Tensor, response_mask: torch.Tensor, index: np.ndarray, **kwargs
):
    """
    A simple custom advantage estimator that uses response-level rewards
    and computes advantages against a simple baseline.

    This is just an example - replace with your own logic.
    """
    with torch.no_grad():
        response_rewards = (token_level_rewards * response_mask).sum(dim=-1, keepdim=True)

        # Simple baseline: use the mean reward across the batch
        baseline = response_rewards.mean()
        advantages = (response_rewards - baseline) * response_mask
        returns = advantages.clone()

    return advantages, returns


@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):
    # Register the custom advantage estimator
    AdvantageEstimatorRegistry.register("simple_baseline", compute_simple_baseline_advantage)

    # make sure that the training loop is not run on the head node
    exp = BasePPOExp(cfg)
    exp.run()


@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None)
def main(cfg: DictConfig) -> None:
    # validate the arguments
    validate_cfg(cfg)

    initialize_ray(cfg)
    ray.get(skyrl_entrypoint.remote(cfg))


if __name__ == "__main__":
    main()
```
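As a quick sanity check, here is what the estimator produces on a toy batch (a sketch, not part of the PR; the import path assumes the example file above, and the tensors are made up):

```python
import torch
import numpy as np

# Assumes the example module above is importable from the repo root.
from examples.algorithm.custom_advantage_estimator.main_custom_adv_est import (
    compute_simple_baseline_advantage,
)

# Toy batch: 2 responses of 4 tokens each; the reward lands on the last token.
token_level_rewards = torch.tensor([[0.0, 0.0, 0.0, 1.0],
                                    [0.0, 0.0, 0.0, 0.0]])
response_mask = torch.ones_like(token_level_rewards)
index = np.array([0, 1])  # prompt/group ids; unused by this estimator

advantages, returns = compute_simple_baseline_advantage(token_level_rewards, response_mask, index)

# The batch-mean reward (0.5) is the baseline, so advantages are +/-0.5,
# broadcast over every response token:
print(advantages)
# tensor([[ 0.5000,  0.5000,  0.5000,  0.5000],
#         [-0.5000, -0.5000, -0.5000, -0.5000]])
```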
And the accompanying run script:

```bash
set -x

# Example of custom advantage estimator: "simple_baseline"
# Colocated GRPO training+generation for Qwen2.5-0.5B-Instruct on GSM8K.

# uv run examples/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
# export WANDB_API_KEY=<your_key_here>
# bash examples/gsm8k/run_gsm8k.sh

# NOTE (sumanthrh): `micro_train_batch_size_per_gpu` and `micro_forward_batch_size_per_gpu` can be tuned

DATA_DIR="$HOME/data/gsm8k"
NUM_GPUS=4
LOGGER="wandb"  # change to "console" to print to stdout

# Configure the advantage estimator to use
ADV_EST="simple_baseline"

uv run --isolated --extra vllm -m examples.algorithm.custom_advantage_estimator.main_custom_adv_est \
  data.train_data="['$DATA_DIR/train.parquet']" \
  data.val_data="['$DATA_DIR/validation.parquet']" \
  trainer.algorithm.advantage_estimator="$ADV_EST" \
  trainer.policy.model.path="Qwen/Qwen2.5-0.5B-Instruct" \
  trainer.placement.colocate_all=true \
  trainer.strategy=fsdp2 \
  trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
  trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \
  generator.num_inference_engines=$NUM_GPUS \
  generator.inference_engine_tensor_parallel_size=1 \
  trainer.epochs=20 \
  trainer.eval_batch_size=1024 \
  trainer.eval_before_train=true \
  trainer.eval_interval=5 \
  trainer.update_epochs_per_batch=1 \
  trainer.train_batch_size=1024 \
  trainer.policy_mini_batch_size=256 \
  trainer.micro_forward_batch_size_per_gpu=64 \
  trainer.micro_train_batch_size_per_gpu=64 \
  trainer.ckpt_interval=10 \
  trainer.max_prompt_length=512 \
  generator.sampling_params.max_generate_length=1024 \
  trainer.policy.optimizer_config.lr=1.0e-6 \
  trainer.algorithm.use_kl_loss=true \
  generator.backend=vllm \
  generator.run_engines_locally=true \
  generator.weight_sync_backend=nccl \
  generator.async_engine=true \
  generator.batched=true \
  environment.env_class=gsm8k \
  generator.n_samples_per_prompt=5 \
  generator.gpu_memory_utilization=0.8 \
  trainer.logger="$LOGGER" \
  trainer.project_name="custom_adv_est_gsm8k" \
  trainer.run_name="custom_adv_est_gsm8k_test" \
  trainer.resume_mode=null \
  trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \
  "$@"
```
The change to `validate_cfg` in `skyrl_train/entrypoints/main_base.py` removes the hardcoded estimator check:

```diff
@@ -140,10 +140,6 @@ def validate_cfg(cfg: DictConfig):
         cfg.trainer.sequence_parallel_backend == "ulysses"
     ), f"only ulysses is supported as of now, got {cfg.trainer.sequence_parallel_backend}"

-    assert cfg.trainer.algorithm.advantage_estimator in (
-        "gae",
-        "grpo",
-    ), f"invalid advantage estimator: {cfg.trainer.algorithm.advantage_estimator}"
     # if advantage estimator is GAE, then critic path should be provided
     if cfg.trainer.algorithm.advantage_estimator == "gae":
         assert (
```

Member (Author):
Note: this is an unfortunate deletion. The tricky part is that the advantage estimator registration has to happen on a Ray worker (or be passed into the Ray workers), but our config validation currently happens before calling into the head Ray process, so we can't actually check the configured estimator name against the registry. This will be a problem with any registries we add later, so I opened an issue on this: #116. One option is to move config validation to inside the Ray head worker.

Collaborator:
Hmm, yeah. I guess we don't run any validation for env names for the same reason... this makes sense for now.

Member (Author):
Yeah. Updated the issue (#116) with the env name validation point, too.
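A hypothetical sketch of the workaround discussed above - not part of this PR. It assumes the registry exposes the names it knows about (the `.list()` accessor below is an assumption; the PR only shows `register`), and defers the estimator-name check until after registration has run on the Ray worker:

```python
import ray
from omegaconf import DictConfig

from skyrl_train.entrypoints.main_base import BasePPOExp
from skyrl_train.utils.ppo_utils import AdvantageEstimatorRegistry
from examples.algorithm.custom_advantage_estimator.main_custom_adv_est import (
    compute_simple_baseline_advantage,
)


@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):
    # Registration happens here, on the Ray worker.
    AdvantageEstimatorRegistry.register("simple_baseline", compute_simple_baseline_advantage)

    # Deferred check: by this point the registry is populated, so an unknown
    # estimator name fails fast. NOTE: `.list()` is an assumed accessor, not
    # an API confirmed by this PR.
    est = cfg.trainer.algorithm.advantage_estimator
    registered = AdvantageEstimatorRegistry.list()
    assert est in registered, f"invalid advantage estimator: {est}, registered: {registered}"

    exp = BasePPOExp(cfg)
    exp.run()
```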
Reviewer:
Do you think we should add this example to a docs page? Maybe it's time for us to initialize a docs section for "Algorithms", and this small example could be the "Custom algorithms" page under that section. We can add explainers of the relevant configs to tune for each algorithm in each of the doc pages (i.e. for PPO, GRPO, DAPO, GSPO, etc.) - wdyt? At the very least we should update the config page (https://github.com/NovaSky-AI/SkyRL/blob/main/skyrl-train/docs/configuration/config.rst) to mention the AdvantageEstimator registry.

Author:
Yes, good call. Let me take this as a todo? I am working on one more DAPO feature and on GSPO. I will wrap these up in the next couple of days, and that will be a good time to introduce an "Algorithms" section.

Created issue at #119.