diff --git a/docs/sphinx_doc/assets/gsm8k_ruler_eval_accuracy.png b/docs/sphinx_doc/assets/gsm8k_ruler_eval_accuracy.png new file mode 100644 index 0000000000..0eac700fc0 Binary files /dev/null and b/docs/sphinx_doc/assets/gsm8k_ruler_eval_accuracy.png differ diff --git a/docs/sphinx_doc/assets/gsm8k_ruler_gold_reward.png b/docs/sphinx_doc/assets/gsm8k_ruler_gold_reward.png new file mode 100644 index 0000000000..e353864253 Binary files /dev/null and b/docs/sphinx_doc/assets/gsm8k_ruler_gold_reward.png differ diff --git a/docs/sphinx_doc/assets/gsm8k_ruler_judge_success.png b/docs/sphinx_doc/assets/gsm8k_ruler_judge_success.png new file mode 100644 index 0000000000..c14d7d0429 Binary files /dev/null and b/docs/sphinx_doc/assets/gsm8k_ruler_judge_success.png differ diff --git a/docs/sphinx_doc/assets/gsm8k_ruler_reward.png b/docs/sphinx_doc/assets/gsm8k_ruler_reward.png new file mode 100644 index 0000000000..ec3d5cb2cb Binary files /dev/null and b/docs/sphinx_doc/assets/gsm8k_ruler_reward.png differ diff --git a/examples/grpo_gsm8k_ruler/README.md b/examples/grpo_gsm8k_ruler/README.md new file mode 100644 index 0000000000..0f1ac1e83a --- /dev/null +++ b/examples/grpo_gsm8k_ruler/README.md @@ -0,0 +1,62 @@ +# RL on GSM8K with RULER reward + +This example shows a toy implementation of ART's [RULER](https://art.openpipe.ai/fundamentals/ruler) on GSM8k task and GRPO. + +RULER (Relative Universal LLM-Elicited Rewards) is a general-purpose reward function that uses an LLM-as-judge to rank the rollouts for a given task. + +https://github.com/OpenPipe/ART/blob/main/src/art/rewards/ruler.py + + +## Configurations and Metrics + +The config files are located in [`gsm8k_ruler.yaml`](gsm8k_ruler.yaml) and [`train_gsm8k_ruler.yaml`](train_gsm8k_ruler.yaml). + +Some key configs in this example are: +* `default_workflow_type`: set to `math_ruler_workflow` +* `auxiliary_models`: LLM-as-a-judge for RULER; need to set `max_prompt_tokens`, `max_response_tokens`, `max_model_len` appropriately +* `std_threshold` for GRPO advantage: set to small value, filter out group of experiences with same rewards (e.g., when RULER fails to return valid scores, they are set to all zero) +* `sync_style`: use `dynamic_by_explorer`, due to filtering of experiences +* `lr`: set to small value (2e-6) for stability, as rewards can be noisy + + +Some important metrics to pay attention to are: +* `reward`: reward calculated by RULER +* `gold_reward`: sum of `accuracy_reward` and `format_reward`, rule-based calculation with ground truth (as in original GSM8k example) +* `judge_success`: whether RULER successfully returns a valid score +* `eval_accuracy`: accuracy on the evaluation set + + +## Results +We show the results below: + + +![reward](../../docs/sphinx_doc/assets/gsm8k_ruler_reward.png) + +![gold_reward](../../docs/sphinx_doc/assets/gsm8k_ruler_gold_reward.png) + +![eval_accuracy](../../docs/sphinx_doc/assets/gsm8k_ruler_eval_accuracy.png) + +![judge_success](../../docs/sphinx_doc/assets/gsm8k_ruler_judge_success.png) + + +Also, an example response from the judge LLM is shown below: +``` +Let's evaluate each solution based on its accuracy, logical consistency, and adherence to the format: +1. **Candidate Solution 1**: This solution incorrectly calculates the number of blue candles. It attempts to find the difference between the ratio components rather than using the correct proportional method. The rounding down to a whole number is unnecessary and incorrect. + - Score: 0.1 +2. **Candidate Solution 2**: This solution correctly identifies the steps needed to find the number of blue candles, but the calculation at the end is incorrect. It mistakenly multiplies the number of sets by 3 instead of using the correct method to derive the number of blue candles. + - Score: 0.4 +3. **Candidate Solution 3**: This solution correctly breaks down the problem, calculates the number of groups of 5 red candles, and then uses the ratio to find the number of blue candles. The final answer is correct. + - Score: 0.9 +4. **Candidate Solution 4**: Similar to Solution 3, this solution correctly identifies the steps and performs the calculations accurately, leading to the right answer. + - Score: 0.9 +5. **Candidate Solution 5**: This solution correctly follows the logical steps to find the number of blue candles and arrives at the correct answer. It repeats the answer tag, which isn't ideal but the reasoning and final result are accurate. + - Score: 0.9 +6. **Candidate Solution 6**: This solution provides a clear and accurate explanation of the steps to find the number of blue candles and arrives at the correct answer. It explains the division and multiplication steps well. + - Score: 1.0 +7. **Candidate Solution 7**: This solution misunderstands the ratio application, calculating only 3 groups of 5 red candles, which is incorrect. It results in an incorrect number of blue candles. + - Score: 0.1 +8. **Candidate Solution 8**: This solution correctly applies the ratio and arrives at the right number of blue candles. It succinctly explains the calculation process. + - Score: 0.9 +[0.1, 0.4, 0.9, 0.9, 0.9, 1.0, 0.1, 0.9] +``` diff --git a/examples/grpo_gsm8k_ruler/gsm8k_ruler.yaml b/examples/grpo_gsm8k_ruler/gsm8k_ruler.yaml new file mode 100644 index 0000000000..11ef8dd1ea --- /dev/null +++ b/examples/grpo_gsm8k_ruler/gsm8k_ruler.yaml @@ -0,0 +1,72 @@ +project: "Trinity-RFT-gsm8k-ruler" +name: "qwen2.5-1.5B-gsm8k-ruler" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: grpo + advantage_fn_args: + std_threshold: 0.0001 # effectively zero + repeat_times: 8 +model: + model_path: /PATH/TO/MODEL/ + max_response_tokens: 1024 + max_model_len: 1280 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 96 + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'train' + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 1.0 + eval_tasksets: + - name: gsm8k-eval + storage_type: file + path: 'openai/gsm8k' + subset_name: 'main' + split: 'test' + format: + prompt_key: 'question' + response_key: 'answer' + default_workflow_type: 'math_ruler_workflow' + trainer_input: + experience_buffer: + name: gsm8k_buffer + storage_type: queue +explorer: + eval_interval: 10 + runner_num: 32 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 + auxiliary_models: + - model_path: /PATH/TO/Qwen2.5-32B-Instruct + engine_num: 1 + tensor_parallel_size: 2 + enable_thinking: false + max_prompt_tokens: 12288 + max_response_tokens: 12288 + max_model_len: 16384 +synchronizer: + sync_style: dynamic_by_explorer + sync_method: 'nccl' + sync_interval: 5 + sync_timeout: 3600 +trainer: + trainer_type: 'verl' + trainer_config_path: 'examples/grpo_gsm8k_ruler/train_gsm8k_ruler.yaml' + save_interval: 100 diff --git a/examples/grpo_gsm8k_ruler/train_gsm8k_ruler.yaml b/examples/grpo_gsm8k_ruler/train_gsm8k_ruler.yaml new file mode 100644 index 0000000000..4cd4d0beba --- /dev/null +++ b/examples/grpo_gsm8k_ruler/train_gsm8k_ruler.yaml @@ -0,0 +1,48 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 16384 + grad_clip: 1.0 + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 2e-6 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size_per_gpu: 4 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size +trainer: + balance_batch: True + # total_training_steps: null + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + val_before_train: False diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py index 971210ebe9..12d8c4fe11 100644 --- a/trinity/common/workflows/__init__.py +++ b/trinity/common/workflows/__init__.py @@ -11,6 +11,7 @@ from .envs.webshop.webshop_workflow import WebShopWorkflow from .eval_workflow import MathEvalWorkflow from .math_rm_workflow import MathRMWorkflow +from .math_ruler_workflow import MathRULERWorkflow from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow __all__ = [ @@ -31,4 +32,5 @@ "MathEvalWorkflow", "AgentScopeReactV2MathWorkflow", "EmailSearchWorkflow", + "MathRULERWorkflow", ] diff --git a/trinity/common/workflows/math_ruler_workflow.py b/trinity/common/workflows/math_ruler_workflow.py new file mode 100644 index 0000000000..6345e81018 --- /dev/null +++ b/trinity/common/workflows/math_ruler_workflow.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +"""Math workflow with RULER.""" +import ast +from typing import Any, List, Optional, Tuple + +import openai + +from trinity.common.experience import Experience +from trinity.common.models.model import ModelWrapper +from trinity.common.rewards.math_reward import MathRewardFn +from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task +from trinity.utils.log import get_logger + +logger = get_logger(__name__) + + +@WORKFLOWS.register_module("math_ruler_workflow") +class MathRULERWorkflow(SimpleWorkflow): + """A workflow for math with RULER reward function. + + Modified from `MathWorkflow`. + Adapted from https://github.com/OpenPipe/ART/blob/main/src/art/rewards/ruler.py + """ + + def __init__( + self, + *, + task: Task, + model: ModelWrapper, + auxiliary_models: Optional[List[openai.OpenAI]] = None, + ): + super().__init__( + task=task, + model=model, + auxiliary_models=auxiliary_models, + ) + + def reset(self, task: Task): + """ + Note that in this workflow, MathRewardFn is only used for calculating the 'golden reward', + whereasa the rewards used by RL training are calculated by RULER. + """ + + if task.reward_fn is None: + task.reward_fn = MathRewardFn + if task.reward_fn == MathRewardFn and task.format_args.system_prompt is None: + task.format_args.system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., + reasoning process here + answer here . +""" + # call the SimpleWorkflow.reset + super().reset(task) + + def run(self) -> List[Experience]: + """Modified from SimpleWorkflow.run""" + + messages = self.format_messages() + + self.logger.debug("start chat") + responses = self.model.chat(messages, **self.rollout_args) + + for i, response in enumerate(responses): + gold_reward_dict = self.reward_fn( # type: ignore [misc] + response=response.response_text, # type: ignore [arg-type] + truth=self.truth, + ) + + if response.metrics is None: + response.metrics = {} + + response.metrics.update(gold_reward_dict) + gold_reward = sum(gold_reward_dict.values()) + response.metrics.update({"gold_reward": gold_reward}) + response.eid.run = i + self.run_id_base + + # self.logger.debug( + # f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, gold_reward: {gold_reward}" + # ) + + # === RULER scores as rewards === + assert ( + self.auxiliary_models is not None + ), "Current implementation of RULER requires that auxiliary_models is not None." + judge_success, ruler_scores = self.get_ruler_scores( + responses=responses, judger=self.auxiliary_models[0] + ) + for i, response in enumerate(responses): + response.reward = ruler_scores[i] + response.metrics.update({"judge_success": float(judge_success)}) + + return responses + + def get_ruler_scores( + self, responses: List[Experience], judger: Any + ) -> Tuple[bool, List[float]]: + """Get RULER scores""" + + num_responses = len(responses) + + # Step 1: format prompt for judge + ruler_system_prompt = f"You are a fair judge. The user will provide a question and {num_responses} candidate solutions to it. Your task is to compare the solutions, see how well they resolve the question, and assign a score within the range [0, 1] for each solution." + + question_prompt = ( + f"Question: {self.task_desc}\n\n" + f"""Solution format requirement: first thinks about the reasoning process in the mind and then provides the final answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., + reasoning process here + answer here .""" + ) + + solutions_prompt_parts = [ + f"Candidate solution {i + 1}: {response.response_text}" + for i, response in enumerate(responses) + ] + solutions_prompt = "\n\n".join(solutions_prompt_parts) + + ruler_user_prompt = f""" +Below is a question and several candidate solutions. + +{question_prompt} + +{solutions_prompt} + +Please assign a score within the range [0, 1] for each of them, reflecting how well they solve the question. +You may compare them against each other and think step by step before returning your final scores, but keep your reasoning process brief and concise when possible. + +Conclude your response with a list of scores, in the following format: [score for solution 1, score for solution 2, ..., score for solution {num_responses + 1}] +""" + + # Step 2: invoke judger LLM + messages = [ + {"role": "system", "content": ruler_system_prompt}, + {"role": "user", "content": ruler_user_prompt}, + ] + completion = judger.chat.completions.create( + model=judger.model_path, messages=messages, stream=False + ) + judger_response = completion.choices[0].message.content + logger.info(f"LLM judge response: {judger_response}") + + # Step 3: extract scores from judger's response + idx1, idx2 = judger_response.rfind("["), judger_response.rfind("]") + if (idx1 == -1) or (idx2 == -1) or (idx1 > idx2): + logger.warning("Unable to extract a list from judger response, set scores to all zero.") + return False, [0.0 for _ in range(num_responses)] + lst_as_str = judger_response[idx1 : (idx2 + 1)] + try: + scores = ast.literal_eval(lst_as_str) + scores = [max(0.0, min(1.0, score)) for score in scores] # clip to range [0, 1] + return True, scores + except Exception: + logger.warning("Unable to parse the list in judger response, set scores to all zero.") + return False, [0.0 for _ in range(num_responses)]