[Example] GRPO on GSM8K with RULER reward #239
Merged
# RL on GSM8K with RULER reward

This example shows a toy implementation of ART's [RULER](https://art.openpipe.ai/fundamentals/ruler) on the GSM8K task with GRPO.

RULER (Relative Universal LLM-Elicited Rewards) is a general-purpose reward function that uses an LLM-as-judge to rank the rollouts for a given task. The reference implementation is available at:

https://github.com/OpenPipe/ART/blob/main/src/art/rewards/ruler.py
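At a high level, one judge call scores a whole group of rollouts jointly, and those scores are used directly as the RL rewards. The sketch below only illustrates the idea; `ask_judge` and the prompt wording are hypothetical stand-ins, not the actual ART or Trinity-RFT code (the real workflow used in this example is shown further down).

```python
# Minimal RULER-style reward sketch (illustrative; not the actual implementation).
from typing import Callable, List


def ruler_rewards(
    question: str,
    responses: List[str],
    ask_judge: Callable[[str], List[float]],  # hypothetical judge-LLM wrapper
) -> List[float]:
    """Score all rollouts of one question with a single judge call."""
    prompt = (
        f"Question: {question}\n\n"
        + "\n\n".join(f"Candidate {i + 1}: {r}" for i, r in enumerate(responses))
        + f"\n\nReturn one score in [0, 1] per candidate ({len(responses)} scores)."
    )
    scores = ask_judge(prompt)
    # The clipped scores are used directly as rewards for the corresponding rollouts.
    return [min(1.0, max(0.0, s)) for s in scores]
```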
## Configurations and Metrics

The config files are located in [`gsm8k_ruler.yaml`](gsm8k_ruler.yaml) and [`train_gsm8k_ruler.yaml`](train_gsm8k_ruler.yaml).

Some key configs in this example are:
* `default_workflow_type`: set to `math_ruler_workflow`
* `auxiliary_models`: the LLM-as-a-judge used by RULER; `max_prompt_tokens`, `max_response_tokens`, and `max_model_len` need to be large enough for the judge prompt (which contains all candidate solutions) and its response
* `std_threshold` for the GRPO advantage: set to a small value to filter out groups of experiences that all share the same reward (e.g., when RULER fails to return valid scores, every reward in the group is set to zero); a minimal sketch of this filtering is shown right after this list
* `sync_style`: use `dynamic_by_explorer`, since experiences may be filtered out
* `lr`: set to a small value (2e-6) for stability, as RULER rewards can be noisy
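The `std_threshold` filtering can be pictured as follows. This is only a minimal sketch under the usual GRPO convention (group-normalized rewards), not Trinity-RFT's actual advantage implementation: a group whose rewards are (nearly) identical carries no learning signal, so it is dropped instead of producing degenerate advantages.

```python
# Sketch of std_threshold filtering for GRPO group advantages (illustrative only).
from statistics import mean, stdev
from typing import List, Tuple


def group_advantages(rewards: List[float], std_threshold: float = 1e-4) -> Tuple[List[float], bool]:
    """Return (advantages, keep) for one group of rollouts of the same prompt."""
    mu = mean(rewards)
    sigma = stdev(rewards) if len(rewards) > 1 else 0.0
    if sigma <= std_threshold:
        # All rewards are (nearly) identical, e.g. RULER failed and returned all
        # zeros: drop the group rather than emit zero-signal advantages.
        return [0.0] * len(rewards), False
    return [(r - mu) / sigma for r in rewards], True


# A failed-judge group (all-zero rewards) is filtered out ...
_, keep = group_advantages([0.0] * 8)
assert keep is False
# ... while a group with diverse RULER scores is kept.
adv, keep = group_advantages([0.1, 0.4, 0.9, 0.9, 0.9, 1.0, 0.1, 0.9])
assert keep is True
```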
Some important metrics to pay attention to are:
* `reward`: the reward calculated by RULER
* `gold_reward`: the sum of `accuracy_reward` and `format_reward`, a rule-based calculation against the ground truth (as in the original GSM8K example)
* `judge_success`: whether RULER successfully returns valid scores
* `eval_accuracy`: accuracy on the evaluation set
## Results

We show the results below (four plots).
Also, an example response from the judge LLM is shown below:
```
Let's evaluate each solution based on its accuracy, logical consistency, and adherence to the format:
1. **Candidate Solution 1**: This solution incorrectly calculates the number of blue candles. It attempts to find the difference between the ratio components rather than using the correct proportional method. The rounding down to a whole number is unnecessary and incorrect.
- Score: 0.1
2. **Candidate Solution 2**: This solution correctly identifies the steps needed to find the number of blue candles, but the calculation at the end is incorrect. It mistakenly multiplies the number of sets by 3 instead of using the correct method to derive the number of blue candles.
- Score: 0.4
3. **Candidate Solution 3**: This solution correctly breaks down the problem, calculates the number of groups of 5 red candles, and then uses the ratio to find the number of blue candles. The final answer is correct.
- Score: 0.9
4. **Candidate Solution 4**: Similar to Solution 3, this solution correctly identifies the steps and performs the calculations accurately, leading to the right answer.
- Score: 0.9
5. **Candidate Solution 5**: This solution correctly follows the logical steps to find the number of blue candles and arrives at the correct answer. It repeats the answer tag, which isn't ideal, but the reasoning and final result are accurate.
- Score: 0.9
6. **Candidate Solution 6**: This solution provides a clear and accurate explanation of the steps to find the number of blue candles and arrives at the correct answer. It explains the division and multiplication steps well.
- Score: 1.0
7. **Candidate Solution 7**: This solution misunderstands the ratio application, calculating only 3 groups of 5 red candles, which is incorrect. It results in an incorrect number of blue candles.
- Score: 0.1
8. **Candidate Solution 8**: This solution correctly applies the ratio and arrives at the right number of blue candles. It succinctly explains the calculation process.
- Score: 0.9
[0.1, 0.4, 0.9, 0.9, 0.9, 1.0, 0.1, 0.9]
```
`gsm8k_ruler.yaml`:

```yaml
project: "Trinity-RFT-gsm8k-ruler"
name: "qwen2.5-1.5B-gsm8k-ruler"
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
  algorithm_type: grpo
  advantage_fn_args:
    std_threshold: 0.0001  # effectively zero
  repeat_times: 8
model:
  model_path: /PATH/TO/MODEL/
  max_response_tokens: 1024
  max_model_len: 1280
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 1
  batch_size: 96
  explorer_input:
    taskset:
      name: gsm8k
      storage_type: file
      path: 'openai/gsm8k'
      subset_name: 'main'
      split: 'train'
      format:
        prompt_key: 'question'
        response_key: 'answer'
      rollout_args:
        temperature: 1.0
    eval_tasksets:
      - name: gsm8k-eval
        storage_type: file
        path: 'openai/gsm8k'
        subset_name: 'main'
        split: 'test'
        format:
          prompt_key: 'question'
          response_key: 'answer'
    default_workflow_type: 'math_ruler_workflow'
  trainer_input:
    experience_buffer:
      name: gsm8k_buffer
      storage_type: queue
explorer:
  eval_interval: 10
  runner_num: 32
  rollout_model:
    engine_type: vllm_async
    engine_num: 2
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: true
    dtype: bfloat16
    seed: 42
  auxiliary_models:
    - model_path: /PATH/TO/Qwen2.5-32B-Instruct
      engine_num: 1
      tensor_parallel_size: 2
      enable_thinking: false
      max_prompt_tokens: 12288
      max_response_tokens: 12288
      max_model_len: 16384
synchronizer:
  sync_style: dynamic_by_explorer
  sync_method: 'nccl'
  sync_interval: 5
  sync_timeout: 3600
trainer:
  trainer_type: 'verl'
  trainer_config_path: 'examples/grpo_gsm8k_ruler/train_gsm8k_ruler.yaml'
  save_interval: 100
```
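Assuming the two config files live under `examples/grpo_gsm8k_ruler/` (as referenced by `trainer_config_path`), the example should be launchable with Trinity-RFT's usual entry point, roughly as follows:

```bash
# Launch the explorer/trainer with this config (standard Trinity-RFT CLI assumed).
trinity run --config examples/grpo_gsm8k_ruler/gsm8k_ruler.yaml
```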
`train_gsm8k_ruler.yaml` (the veRL trainer config referenced above):

```yaml
actor_rollout_ref:
  hybrid_engine: True
  model:
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: True
    use_remove_padding: True  # False
  actor:
    strategy: fsdp  # This is for backward-compatibility
    ppo_micro_batch_size_per_gpu: 4
    use_dynamic_bsz: True  # False
    ppo_max_token_len_per_gpu: 16384
    grad_clip: 1.0
    ppo_epochs: 1
    shuffle: False
    ulysses_sequence_parallel_size: 1  # sp size
    optim:
      lr: 2e-6
      lr_warmup_steps_ratio: 0.  # the total steps will be injected during runtime
      # min_lr_ratio: null  # only useful for warmup with cosine
      warmup_style: constant  # select from constant/cosine
      total_training_steps: -1  # must be overridden by the program
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: False
      optimizer_offload: False
      fsdp_size: -1
  ref:
    fsdp_config:
      param_offload: False
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
    log_prob_micro_batch_size_per_gpu: 4
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}  # sp size
trainer:
  balance_batch: True
  # total_training_steps: null
  # auto: find the last ckpt to resume. If can't find, start from scratch
  resume_mode: auto  # or resume_path
  default_hdfs_dir: null
  remove_previous_ckpt_in_save: False
  del_local_ckpt_after_load: False
  val_before_train: False
```
The workflow implementation registered as `math_ruler_workflow`:

```python
# -*- coding: utf-8 -*-
"""Math workflow with RULER."""
import ast
from typing import Any, List, Optional, Tuple

import openai

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.rewards.math_reward import MathRewardFn
from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
from trinity.utils.log import get_logger

logger = get_logger(__name__)


@WORKFLOWS.register_module("math_ruler_workflow")
class MathRULERWorkflow(SimpleWorkflow):
    """A workflow for math with RULER reward function.

    Modified from `MathWorkflow`.
    Adapted from https://github.com/OpenPipe/ART/blob/main/src/art/rewards/ruler.py
    """

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        super().__init__(
            task=task,
            model=model,
            auxiliary_models=auxiliary_models,
        )

    def reset(self, task: Task):
        """
        Note that in this workflow, MathRewardFn is only used for calculating the 'gold reward',
        whereas the rewards used by RL training are calculated by RULER.
        """
        if task.reward_fn is None:
            task.reward_fn = MathRewardFn
        if task.reward_fn == MathRewardFn and task.format_args.system_prompt is None:
            task.format_args.system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think>
<answer> answer here </answer>.
"""
        # call the SimpleWorkflow.reset
        super().reset(task)

    def run(self) -> List[Experience]:
        """Modified from SimpleWorkflow.run"""
        messages = self.format_messages()

        self.logger.debug("start chat")
        responses = self.model.chat(messages, **self.rollout_args)

        for i, response in enumerate(responses):
            gold_reward_dict = self.reward_fn(  # type: ignore [misc]
                response=response.response_text,  # type: ignore [arg-type]
                truth=self.truth,
            )

            if response.metrics is None:
                response.metrics = {}

            response.metrics.update(gold_reward_dict)
            gold_reward = sum(gold_reward_dict.values())
            response.metrics.update({"gold_reward": gold_reward})
            response.eid.run = i + self.run_id_base

            # self.logger.debug(
            #     f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, gold_reward: {gold_reward}"
            # )

        # === RULER scores as rewards ===
        assert (
            self.auxiliary_models is not None
        ), "Current implementation of RULER requires that auxiliary_models is not None."
        judge_success, ruler_scores = self.get_ruler_scores(
            responses=responses, judger=self.auxiliary_models[0]
        )
        for i, response in enumerate(responses):
            response.reward = ruler_scores[i]
            response.metrics.update({"judge_success": float(judge_success)})

        return responses

    def get_ruler_scores(
        self, responses: List[Experience], judger: Any
    ) -> Tuple[bool, List[float]]:
        """Get RULER scores"""
        num_responses = len(responses)

        # Step 1: format prompt for judge
        ruler_system_prompt = f"You are a fair judge. The user will provide a question and {num_responses} candidate solutions to it. Your task is to compare the solutions, see how well they resolve the question, and assign a score within the range [0, 1] for each solution."

        question_prompt = (
            f"Question: {self.task_desc}\n\n"
            f"""Solution format requirement: first think about the reasoning process in the mind and then provide the final answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think>
<answer> answer here </answer>."""
        )

        solutions_prompt_parts = [
            f"Candidate solution {i + 1}: {response.response_text}"
            for i, response in enumerate(responses)
        ]
        solutions_prompt = "\n\n".join(solutions_prompt_parts)

        ruler_user_prompt = f"""
Below is a question and several candidate solutions.

{question_prompt}

{solutions_prompt}

Please assign a score within the range [0, 1] for each of them, reflecting how well they solve the question.
You may compare them against each other and think step by step before returning your final scores, but keep your reasoning process brief and concise when possible.

Conclude your response with a list of scores, in the following format: [score for solution 1, score for solution 2, ..., score for solution {num_responses}]
"""

        # Step 2: invoke judge LLM
        messages = [
            {"role": "system", "content": ruler_system_prompt},
            {"role": "user", "content": ruler_user_prompt},
        ]
        completion = judger.chat.completions.create(
            model=judger.model_path, messages=messages, stream=False
        )
        judger_response = completion.choices[0].message.content
        logger.info(f"LLM judge response: {judger_response}")

        # Step 3: extract scores from judge's response
        idx1, idx2 = judger_response.rfind("["), judger_response.rfind("]")
        if (idx1 == -1) or (idx2 == -1) or (idx1 > idx2):
            logger.warning("Unable to extract a list from judge response, set scores to all zero.")
            return False, [0.0 for _ in range(num_responses)]
        lst_as_str = judger_response[idx1 : (idx2 + 1)]
        try:
            scores = ast.literal_eval(lst_as_str)
            if len(scores) != num_responses:
                raise ValueError("score list length does not match the number of responses")
            scores = [max(0.0, min(1.0, score)) for score in scores]  # clip to range [0, 1]
            return True, scores
        except Exception:
            logger.warning("Unable to parse the list in judge response, set scores to all zero.")
            return False, [0.0 for _ in range(num_responses)]
```