-
Notifications
You must be signed in to change notification settings - Fork 55
Group-relative REINFORCE Families #292
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
881aa4a
[algo] Implemented rec family.
yaochaorui 536b758
[example] REC on GSM8k.
yaochaorui ad8e466
[algo] New REC family members.
yaochaorui 5f2fa5f
[algo] New REP family members.
yaochaorui 89387b9
[fix]
yaochaorui a725328
[algo] New RED family members.
yaochaorui e8b0c71
[algo] New RED family member.
yaochaorui bbd6a01
[minor]
yaochaorui 7200fa9
[algo] REC family.
yaochaorui 5f1dfe6
[clean] so long!
yaochaorui 6d22577
[README] Group-relative REINFORCE family.
yaochaorui 296240b
Update README to reflect single config file
yaochaorui c17ebee
Update README to reflect config file changes
yaochaorui 9c72092
Correct config file description in README.md
yaochaorui e38e4b0
[minor] Added 'temp' in the default args.
yaochaorui c54bdd8
[minor] Use None instead of str "none".
yaochaorui c6d2af4
[example] Updated.
yaochaorui 2ae8f89
[Minor]
yaochaorui 9a1777f
[Minor]
yaochaorui File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,221 @@ | ||
| # Example: REC on GSM8k dataset | ||
|
|
||
| This example shows the usage of REC on the [GSM8k dataset](https://huggingface.co/datasets/openai/gsm8k). | ||
|
|
||
| For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md). | ||
|
|
||
| The config file is located in [`gsm8k.yaml`](gsm8k.yaml). | ||
|
|
||
| # Group-relative REINFORCE Families | ||
| This folder provides **example configurations** for running different group-relative REINFORCE families within Trinity-RFT. | ||
|
|
||
| It includes three major families: | ||
|
|
||
| - **REC family** (clipping + importance sampling) | ||
| - **REP family** (regularization-based variants) | ||
| - **RED family** (data-distribution shaping strategies) | ||
|
|
||
| We also provide baseline implementations such as **Vanilla REINFORCE** and **GRPO**. | ||
|
|
||
| All algorithms are instantiated through modular YAML configs for easy reproduction and extension. | ||
|
|
||
| # Summary Table 📝 | ||
|
|
||
| | Family | Variants | Key Idea | | ||
| | ------------- | ----------------------------------------------- | ----------------------------------- | | ||
| | **Baselines** | REINFORCE, GRPO | Standard references | | ||
| | **REC** | OneSide-NoIS, OneSide-IS, TwoSide-IS, Ring-NoIS | Clipping + importance sampling | | ||
| | **REP** | AsymRE, OPMD | Regularization | | ||
| | **RED** | Drop, Weight | Data-distribution shaping | | ||
|
|
||
|
|
||
|
|
||
| # Instantiations | ||
|
|
||
| ## Baselines | ||
|
|
||
| ### REINFORCE | ||
| Vanilla REINFORCE with group mean as baseline. | ||
|
|
||
| ``` | ||
| algorithm: | ||
| algorithm_type: rec | ||
| policy_loss_fn_args: | ||
| epsilon_low: 0.2 | ||
| epsilon_high: 0.2 | ||
| clip_mode: "none" # no clipping | ||
| weight: "none" # uniform weighting for samples | ||
| temp: 1.0 | ||
| regularizer: "none" # no regularizer | ||
| regularizer_coef: 0.0 | ||
| advantage_fn_args: | ||
| std_normalize: false | ||
| ``` | ||
|
|
||
| ### GRPO | ||
| GRPO implemented with zero KL regularizer. Regularization can be enabled via `kl_loss_fn` and `kl_loss_fn_args`. | ||
|
|
||
| ``` | ||
| algorithm: | ||
| algorithm_type: rec | ||
| policy_loss_fn_args: | ||
| epsilon_low: 0.2 | ||
| epsilon_high: 0.2 | ||
| clip_mode: "one-side" | ||
| weight: "importance_sampling" | ||
| temp: 1.0 | ||
| regularizer: "none" | ||
| regularizer_coef: 0.0 | ||
| advantage_fn_args: | ||
| std_normalize: true | ||
| kl_loss_fn: 'k2' | ||
| kl_loss_fn_args: | ||
| kl_coef: 0.0 | ||
|
|
||
| ``` | ||
|
|
||
| ## REC family | ||
| Variants of clipping and importance-sampling strategies. | ||
| - REC-OneSide-NoIS | ||
| - REC-OneSide-IS | ||
| - REC-TwoSide-IS | ||
| - REC-Ring-NoIS | ||
|
|
||
| ### REC-OneSide-NoIS | ||
|
|
||
| ``` | ||
| algorithm: | ||
| algorithm_type: rec | ||
| policy_loss_fn_args: | ||
| epsilon_low: 0.2 | ||
| epsilon_high: 0.2 | ||
| clip_mode: "one-side" | ||
| weight: "none" | ||
| temp: 1.0 | ||
| regularizer: "none" | ||
| regularizer_coef: 0.0 | ||
| advantage_fn_args: | ||
| std_normalize: false | ||
| ``` | ||
|
|
||
| ### REC-OneSide-IS | ||
|
|
||
| ``` | ||
| algorithm: | ||
| algorithm_type: rec | ||
| policy_loss_fn_args: | ||
| epsilon_low: 0.2 | ||
| epsilon_high: 0.2 | ||
| clip_mode: "one-side" | ||
| weight: "importance_sampling" | ||
| temp: 1.0 | ||
| regularizer: "none" | ||
| regularizer_coef: 0.0 | ||
| advantage_fn_args: | ||
| std_normalize: false | ||
| ``` | ||
|
|
||
| ### REC-TwoSide-IS | ||
|
|
||
| ``` | ||
| algorithm: | ||
| algorithm_type: rec | ||
| policy_loss_fn_args: | ||
| epsilon_low: 0.2 | ||
| epsilon_high: 0.2 | ||
| clip_mode: "two-side" | ||
| weight: "importance_sampling" | ||
| temp: 1.0 | ||
| regularizer: "none" | ||
| regularizer_coef: 0.0 | ||
| advantage_fn_args: | ||
| std_normalize: false | ||
| ``` | ||
| ### REC-Ring-NoIS | ||
|
|
||
| ``` | ||
| algorithm: | ||
| algorithm_type: rec | ||
| policy_loss_fn_args: | ||
| epsilon_low: 0.2 | ||
| epsilon_high: 0.2 | ||
| epsilon_low_prime: 0.6 | ||
| epsilon_high_prime: 2.0 | ||
| clip_mode: "ring" | ||
| weight: "none" | ||
| temp: 1.0 | ||
| regularizer: "none" | ||
| regularizer_coef: 0.0 | ||
| advantage_fn_args: | ||
| std_normalize: false | ||
| ``` | ||
|
|
||
| ## REP family | ||
|
|
||
| Regularization-based algorithms. | ||
| - AsymRE (forward KL regularization) | ||
| - Kimi’s OPMD (k2 regularizer) | ||
|
|
||
| ### AsymRE | ||
|
|
||
| ``` | ||
| algorithm: | ||
| algorithm_type: rec | ||
| policy_loss_fn_args: | ||
| clip_mode: "none" | ||
| weight: "none" | ||
| temp: 1.0 | ||
| regularizer: "forward-kl" | ||
| regularizer_coef: 0.1 | ||
| advantage_fn_args: | ||
| std_normalize: false | ||
| ``` | ||
|
|
||
|
|
||
| ### Kimi's OPMD | ||
|
|
||
| ``` | ||
| algorithm: | ||
| algorithm_type: rec | ||
| policy_loss_fn_args: | ||
| clip_mode: "none" | ||
| weight: "none" | ||
| regularizer: "k2" | ||
| regularizer_coef: 0.1 | ||
| advantage_fn_args: | ||
| std_normalize: false | ||
| ``` | ||
|
|
||
| ## RED family | ||
| Data-distribution shaping variants. | ||
| - RED-Drop (drop extra negative examples to balance the positive examples v.s. negative examples) | ||
| - RED-Weight (advantage-weighting strategy) | ||
|
|
||
| ### RED-Drop | ||
|
|
||
| ``` | ||
| algorithm: | ||
| algorithm_type: rec | ||
| policy_loss_fn_args: | ||
| clip_mode: "none" | ||
| weight: "none" | ||
| regularizer: "none" | ||
| advantage_fn_args: | ||
| std_normalize: false | ||
| drop: "balance" | ||
| ``` | ||
|
|
||
|
|
||
| ### RED-Weight | ||
|
|
||
| ``` | ||
| algorithm: | ||
| algorithm_type: rec | ||
| policy_loss_fn_args: | ||
| clip_mode: "none" | ||
| weight: "advantage" | ||
| regularizer: "none" | ||
| temp: 1.0 | ||
| advantage_fn_args: | ||
| std_normalize: false | ||
| ``` | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| # Configuration file for the REC GSM8k project. | ||
| project: "Trinity-RFT-GSM8K" | ||
| name: rec_gsm8k | ||
| checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} | ||
| mode: both | ||
| model: | ||
| model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct} | ||
| max_response_tokens: 1024 | ||
| max_model_len: 1280 | ||
| algorithm: | ||
| algorithm_type: rec | ||
| policy_loss_fn_args: | ||
| epsilon_low: 0.2 | ||
| epsilon_high: 0.2 | ||
| clip_mode: "none" | ||
| weight: "none" | ||
| temp: 1.0 | ||
| regularizer: "none" | ||
| regularizer_coef: 0.0 | ||
| advantage_fn_args: | ||
| std_normalize: false | ||
| repeat_times: 8 | ||
| cluster: | ||
| node_num: 1 | ||
| gpu_per_node: 8 | ||
| buffer: | ||
| total_steps: 100 | ||
| batch_size: 96 | ||
| explorer_input: | ||
| taskset: | ||
| name: gsm8k | ||
| storage_type: file | ||
| path: ${oc.env:TRINITY_TASKSET_PATH} | ||
| split: train | ||
| format: | ||
| prompt_key: question | ||
| response_key: answer | ||
| rollout_args: | ||
| temperature: 1.0 | ||
| eval_tasksets: | ||
| - name: gsm8k-eval | ||
| storage_type: file | ||
| path: ${oc.env:TRINITY_EVAL_TASKSET_PATH} | ||
| split: test | ||
| format: | ||
| prompt_key: question | ||
| response_key: answer | ||
| default_workflow_type: math_workflow | ||
| trainer_input: | ||
| experience_buffer: | ||
| name: gsm8k_buffer | ||
| storage_type: queue | ||
| explorer: | ||
| eval_interval: 20 | ||
| runner_num: 64 | ||
| rollout_model: | ||
| engine_type: vllm_async | ||
| engine_num: 4 | ||
| tensor_parallel_size: 1 | ||
| enable_prefix_caching: false | ||
| enforce_eager: true | ||
| dtype: bfloat16 | ||
| seed: 42 | ||
| synchronizer: | ||
| sync_method: nccl | ||
| sync_interval: 20 | ||
| sync_timeout: 1200 | ||
| sync_offset: 0 | ||
| trainer: | ||
| trainer_type: verl | ||
| save_interval: 100 | ||
| trainer_config: | ||
| actor_rollout_ref: | ||
| model: | ||
| use_remove_padding: true | ||
| actor: | ||
| use_dynamic_bsz: true | ||
| ppo_max_token_len_per_gpu: 16384 | ||
| ulysses_sequence_parallel_size: 1 | ||
| optim: | ||
| lr: 1e-6 | ||
| ref: | ||
| log_prob_use_dynamic_bsz: ${trainer.trainer_config.actor_rollout_ref.actor.use_dynamic_bsz} | ||
| log_prob_max_token_len_per_gpu: ${trainer.trainer_config.actor_rollout_ref.actor.ppo_max_token_len_per_gpu} | ||
| ulysses_sequence_parallel_size: ${trainer.trainer_config.actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.