-
Notifications
You must be signed in to change notification settings - Fork 256
[algorithm] Support Dr. GRPO + refactor where policy/critic loss functions are set #133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
4adf474
doctor
erictang000 446c421
x
erictang000 79eed38
x
erictang000 6211347
add example
erictang000 68a9073
x
erictang000 5c3dbab
x
erictang000 ef68cab
x
erictang000 fc95e36
docs
erictang000 57ed3a4
thanks gemini
erictang000 7b180a0
x
erictang000 45c8dff
kl loss false
erictang000 d87631b
refactor critic loss fn and put policy + critic loss get in trainer
erictang000 1c639cf
move from init to setter method
erictang000 a53b09a
Merge branch 'main' of https://github.com/erictang000/SkyRL into dr_grpo
erictang000 053e09e
rename to seq_mean_token_sum_norm
erictang000 86f5bb6
add todo for putting dr grpo settings once we make docs describing al…
erictang000 1b61833
thanks gemini
erictang000 0b27d84
address comment
erictang000 f82ca0b
fix tests
erictang000 98ff01a
thanks gemini
erictang000 4705c80
comments
erictang000 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| set -x | ||
|
|
||
| # Colocated Dr. GRPO training+generation for Qwen2.5-1.5B-Instruct on GSM8K. | ||
|
|
||
| # uv run examples/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k | ||
| # export WANDB_API_KEY=<your_key_here> | ||
| # bash examples/algorithm/drgrpo/run_drgrpo_gsm8k.sh | ||
|
|
||
| # TODO (erictang000): add a description of the algorithm once GRPO docs are added. | ||
|
|
||
| DATA_DIR="$HOME/data/gsm8k" | ||
| NUM_GPUS=4 | ||
| LOGGER="wandb" # change to "console" to print to stdout | ||
|
|
||
| # Dr. GRPO parameters | ||
|
|
||
| LOSS_REDUCTION="seq_mean_token_sum_norm" | ||
| GRPO_NORM_BY_STD=false | ||
| USE_KL_LOSS=false | ||
|
|
||
| uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \ | ||
| data.train_data="['$DATA_DIR/train.parquet']" \ | ||
| data.val_data="['$DATA_DIR/validation.parquet']" \ | ||
| trainer.algorithm.advantage_estimator="grpo" \ | ||
| trainer.algorithm.loss_reduction=$LOSS_REDUCTION \ | ||
| trainer.algorithm.grpo_norm_by_std=$GRPO_NORM_BY_STD \ | ||
| trainer.algorithm.use_kl_loss=$USE_KL_LOSS \ | ||
| trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \ | ||
| trainer.placement.colocate_all=true \ | ||
| trainer.strategy=fsdp2 \ | ||
| trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \ | ||
| trainer.placement.ref_num_gpus_per_node=$NUM_GPUS \ | ||
| generator.num_inference_engines=$NUM_GPUS \ | ||
| generator.inference_engine_tensor_parallel_size=1 \ | ||
| trainer.epochs=20 \ | ||
| trainer.eval_batch_size=1024 \ | ||
| trainer.eval_before_train=true \ | ||
| trainer.eval_interval=5 \ | ||
| trainer.update_epochs_per_batch=1 \ | ||
| trainer.train_batch_size=1024 \ | ||
| trainer.policy_mini_batch_size=256 \ | ||
| trainer.micro_forward_batch_size_per_gpu=64 \ | ||
| trainer.micro_train_batch_size_per_gpu=64 \ | ||
| trainer.ckpt_interval=10 \ | ||
| trainer.max_prompt_length=512 \ | ||
| generator.sampling_params.max_generate_length=1024 \ | ||
| trainer.policy.optimizer_config.lr=1.0e-6 \ | ||
| generator.backend=vllm \ | ||
| generator.run_engines_locally=true \ | ||
| generator.weight_sync_backend=nccl \ | ||
| generator.async_engine=true \ | ||
| generator.batched=true \ | ||
| environment.env_class=gsm8k \ | ||
| generator.n_samples_per_prompt=5 \ | ||
| generator.gpu_memory_utilization=0.8 \ | ||
| trainer.logger="$LOGGER" \ | ||
| trainer.project_name="gsm8k" \ | ||
| trainer.run_name="gsm8k_drgrpo" \ | ||
| trainer.resume_mode=null \ | ||
| trainer.ckpt_path="$HOME/ckpts/gsm8k_1.5B_ckpt" \ | ||
| $@ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.