Binary file added docs/sphinx_doc/assets/gsm8k_ruler_reward.png
62 changes: 62 additions & 0 deletions examples/grpo_gsm8k_ruler/README.md
@@ -0,0 +1,62 @@
# RL on GSM8K with RULER reward

This example shows a toy implementation of ART's [RULER](https://art.openpipe.ai/fundamentals/ruler) reward on the GSM8K task, trained with GRPO.

RULER (Relative Universal LLM-Elicited Rewards) is a general-purpose reward function that uses an LLM-as-judge to score the rollouts for a given task relative to one another.

Reference implementation: https://github.com/OpenPipe/ART/blob/main/src/art/rewards/ruler.py
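
Concretely, the judge is one of the configured `auxiliary_models` and is called as an OpenAI-compatible chat completion that sees the question together with all rollouts of a group and returns one score in [0, 1] per rollout. A rough, self-contained sketch (the endpoint, model name, question, and rollouts below are placeholders, not this example's actual prompts; see `math_ruler_workflow.py` in this example for the real implementation):

```python
from openai import OpenAI

# Illustrative only: base_url, api_key, model, and all texts are placeholder values.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

rollouts = [
    "<think>...</think> <answer>18</answer>",
    "<think>...</think> <answer>20</answer>",
]
user_prompt = (
    "Question: ...\n\n"
    + "\n\n".join(f"Candidate solution {i + 1}: {r}" for i, r in enumerate(rollouts))
    + f"\n\nReturn a list of {len(rollouts)} scores in [0, 1], e.g. [0.2, 0.9]."
)
completion = client.chat.completions.create(
    model="Qwen2.5-32B-Instruct",
    messages=[
        {"role": "system", "content": "You are a fair judge."},
        {"role": "user", "content": user_prompt},
    ],
)
print(completion.choices[0].message.content)  # should end with something like [0.2, 0.9]
```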


## Configurations and Metrics

The config files are located in [`gsm8k_ruler.yaml`](gsm8k_ruler.yaml) and [`train_gsm8k_ruler.yaml`](train_gsm8k_ruler.yaml).

Some key configs in this example are:
* `default_workflow_type`: set to `math_ruler_workflow`
* `auxiliary_models`: the LLM-as-a-judge used by RULER; `max_prompt_tokens`, `max_response_tokens`, and `max_model_len` need to be set appropriately
* `std_threshold` for the GRPO advantage: set to a small value to filter out groups of experiences whose rewards are all identical (e.g., when RULER fails to return valid scores and they are all set to zero); a minimal sketch of this filtering follows the list
* `sync_style`: set to `dynamic_by_explorer`, since experiences may be filtered out
* `lr`: set to a small value (2e-6) for stability, since RULER rewards can be noisy
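
The `std_threshold` filtering can be pictured with the following sketch (illustrative only, not Trinity-RFT's actual implementation; the threshold matches the value in [`gsm8k_ruler.yaml`](gsm8k_ruler.yaml)):

```python
import statistics

def keep_group(rewards: list[float], std_threshold: float = 0.0001) -> bool:
    """Drop a GRPO group whose rewards are (nearly) identical.

    Such groups -- e.g. all-zero rewards when RULER fails to return valid
    scores -- provide no learning signal, since every advantage would be zero.
    """
    return statistics.pstdev(rewards) > std_threshold

print(keep_group([0.0] * 8))                  # False: RULER failed, the group is dropped
print(keep_group([0.1, 0.4, 0.9, 0.9, 1.0]))  # True: rewards differ, the group is kept
```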


Some important metrics to pay attention to are:
* `reward`: the reward calculated by RULER
* `gold_reward`: the sum of `accuracy_reward` and `format_reward`, computed with rule-based checks against the ground truth (as in the original GSM8K example)
* `judge_success`: whether RULER successfully returned valid scores
* `eval_accuracy`: accuracy on the evaluation set


## Results
The curves for the RULER reward, gold reward, evaluation accuracy, and judge success are shown below:


![reward](../../docs/sphinx_doc/assets/gsm8k_ruler_reward.png)

![gold_reward](../../docs/sphinx_doc/assets/gsm8k_ruler_gold_reward.png)

![eval_accuracy](../../docs/sphinx_doc/assets/gsm8k_ruler_eval_accuracy.png)

![judge_success](../../docs/sphinx_doc/assets/gsm8k_ruler_judge_success.png)


Also, an example response from the judge LLM is shown below:
```
Let's evaluate each solution based on its accuracy, logical consistency, and adherence to the format:
1. **Candidate Solution 1**: This solution incorrectly calculates the number of blue candles. It attempts to find the difference between the ratio components rather than using the correct proportional method. The rounding down to a whole number is unnecessary and incorrect.
- Score: 0.1
2. **Candidate Solution 2**: This solution correctly identifies the steps needed to find the number of blue candles, but the calculation at the end is incorrect. It mistakenly multiplies the number of sets by 3 instead of using the correct method to derive the number of blue candles.
- Score: 0.4
3. **Candidate Solution 3**: This solution correctly breaks down the problem, calculates the number of groups of 5 red candles, and then uses the ratio to find the number of blue candles. The final answer is correct.
- Score: 0.9
4. **Candidate Solution 4**: Similar to Solution 3, this solution correctly identifies the steps and performs the calculations accurately, leading to the right answer.
- Score: 0.9
5. **Candidate Solution 5**: This solution correctly follows the logical steps to find the number of blue candles and arrives at the correct answer. It repeats the answer tag, which isn't ideal but the reasoning and final result are accurate.
- Score: 0.9
6. **Candidate Solution 6**: This solution provides a clear and accurate explanation of the steps to find the number of blue candles and arrives at the correct answer. It explains the division and multiplication steps well.
- Score: 1.0
7. **Candidate Solution 7**: This solution misunderstands the ratio application, calculating only 3 groups of 5 red candles, which is incorrect. It results in an incorrect number of blue candles.
- Score: 0.1
8. **Candidate Solution 8**: This solution correctly applies the ratio and arrives at the right number of blue candles. It succinctly explains the calculation process.
- Score: 0.9
[0.1, 0.4, 0.9, 0.9, 0.9, 1.0, 0.1, 0.9]
```
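
The list of scores at the end of such a response is recovered by locating the last bracketed list in the judge's reply and parsing it; if this fails, the scores fall back to all zeros and `judge_success` is recorded as 0. A condensed sketch of the extraction logic in `MathRULERWorkflow.get_ruler_scores` (see [`math_ruler_workflow.py`](../../trinity/common/workflows/math_ruler_workflow.py) for the full implementation):

```python
import ast

def extract_scores(judge_response: str, num_responses: int) -> tuple[bool, list[float]]:
    # Find the last bracketed list in the judge's reply, parse it,
    # and clip each score to [0, 1]; fall back to all zeros on failure.
    idx1, idx2 = judge_response.rfind("["), judge_response.rfind("]")
    if idx1 == -1 or idx2 == -1 or idx1 > idx2:
        return False, [0.0] * num_responses
    try:
        scores = ast.literal_eval(judge_response[idx1 : idx2 + 1])
        return True, [max(0.0, min(1.0, float(s))) for s in scores]
    except Exception:
        return False, [0.0] * num_responses

print(extract_scores("... [0.1, 0.4, 0.9, 0.9, 0.9, 1.0, 0.1, 0.9]", 8))
# (True, [0.1, 0.4, 0.9, 0.9, 0.9, 1.0, 0.1, 0.9])
```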
72 changes: 72 additions & 0 deletions examples/grpo_gsm8k_ruler/gsm8k_ruler.yaml
@@ -0,0 +1,72 @@
project: "Trinity-RFT-gsm8k-ruler"
name: "qwen2.5-1.5B-gsm8k-ruler"
checkpoint_root_dir: /PATH/TO/CHECKPOINT/
algorithm:
algorithm_type: grpo
advantage_fn_args:
std_threshold: 0.0001 # effectively zero
repeat_times: 8
model:
model_path: /PATH/TO/MODEL/
max_response_tokens: 1024
max_model_len: 1280
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 1
batch_size: 96
explorer_input:
taskset:
name: gsm8k
storage_type: file
path: 'openai/gsm8k'
subset_name: 'main'
split: 'train'
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
temperature: 1.0
eval_tasksets:
- name: gsm8k-eval
storage_type: file
path: 'openai/gsm8k'
subset_name: 'main'
split: 'test'
format:
prompt_key: 'question'
response_key: 'answer'
default_workflow_type: 'math_ruler_workflow'
trainer_input:
experience_buffer:
name: gsm8k_buffer
storage_type: queue
explorer:
eval_interval: 10
runner_num: 32
rollout_model:
engine_type: vllm_async
engine_num: 2
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
auxiliary_models:
- model_path: /PATH/TO/Qwen2.5-32B-Instruct
engine_num: 1
tensor_parallel_size: 2
enable_thinking: false
max_prompt_tokens: 12288
max_response_tokens: 12288
max_model_len: 16384
synchronizer:
sync_style: dynamic_by_explorer
sync_method: 'nccl'
sync_interval: 5
sync_timeout: 3600
trainer:
trainer_type: 'verl'
trainer_config_path: 'examples/grpo_gsm8k_ruler/train_gsm8k_ruler.yaml'
save_interval: 100
48 changes: 48 additions & 0 deletions examples/grpo_gsm8k_ruler/train_gsm8k_ruler.yaml
@@ -0,0 +1,48 @@
actor_rollout_ref:
hybrid_engine: True
model:
external_lib: null
override_config: { }
enable_gradient_checkpointing: True
use_remove_padding: True # False
actor:
strategy: fsdp # This is for backward-compatibility
ppo_micro_batch_size_per_gpu: 4
use_dynamic_bsz: True # False
ppo_max_token_len_per_gpu: 16384
grad_clip: 1.0
ppo_epochs: 1
shuffle: False
ulysses_sequence_parallel_size: 1 # sp size
optim:
lr: 2e-6
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
# min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
      total_training_steps: -1 # must be overridden by the program
fsdp_config:
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
param_offload: False
optimizer_offload: False
fsdp_size: -1
ref:
fsdp_config:
param_offload: False
wrap_policy:
# transformer_layer_cls_to_wrap: None
min_num_params: 0
log_prob_micro_batch_size_per_gpu: 4
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
trainer:
balance_batch: True
# total_training_steps: null
# auto: find the last ckpt to resume. If can't find, start from scratch
  resume_mode: auto # or disable, or resume_path if resuming from a specific checkpoint
default_hdfs_dir: null
remove_previous_ckpt_in_save: False
del_local_ckpt_after_load: False
val_before_train: False
2 changes: 2 additions & 0 deletions trinity/common/workflows/__init__.py
@@ -11,6 +11,7 @@
from .envs.webshop.webshop_workflow import WebShopWorkflow
from .eval_workflow import MathEvalWorkflow
from .math_rm_workflow import MathRMWorkflow
from .math_ruler_workflow import MathRULERWorkflow
from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow, Task, Workflow

__all__ = [
@@ -31,4 +32,5 @@
"MathEvalWorkflow",
"AgentScopeReactV2MathWorkflow",
"EmailSearchWorkflow",
"MathRULERWorkflow",
]
152 changes: 152 additions & 0 deletions trinity/common/workflows/math_ruler_workflow.py
@@ -0,0 +1,152 @@
# -*- coding: utf-8 -*-
"""Math workflow with RULER."""
import ast
from typing import Any, List, Optional, Tuple

import openai

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.rewards.math_reward import MathRewardFn
from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
from trinity.utils.log import get_logger

logger = get_logger(__name__)


@WORKFLOWS.register_module("math_ruler_workflow")
class MathRULERWorkflow(SimpleWorkflow):
"""A workflow for math with RULER reward function.

Modified from `MathWorkflow`.
Adapted from https://github.com/OpenPipe/ART/blob/main/src/art/rewards/ruler.py
"""

def __init__(
self,
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
):
super().__init__(
task=task,
model=model,
auxiliary_models=auxiliary_models,
)

def reset(self, task: Task):
"""
Note that in this workflow, MathRewardFn is only used for calculating the 'golden reward',
        whereas the rewards used for RL training are calculated by RULER.
"""

if task.reward_fn is None:
task.reward_fn = MathRewardFn
if task.reward_fn == MathRewardFn and task.format_args.system_prompt is None:
task.format_args.system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think>
<answer> answer here </answer>.
"""
# call the SimpleWorkflow.reset
super().reset(task)

def run(self) -> List[Experience]:
"""Modified from SimpleWorkflow.run"""

messages = self.format_messages()

self.logger.debug("start chat")
responses = self.model.chat(messages, **self.rollout_args)

for i, response in enumerate(responses):
gold_reward_dict = self.reward_fn( # type: ignore [misc]
response=response.response_text, # type: ignore [arg-type]
truth=self.truth,
)

if response.metrics is None:
response.metrics = {}

response.metrics.update(gold_reward_dict)
gold_reward = sum(gold_reward_dict.values())
response.metrics.update({"gold_reward": gold_reward})
response.eid.run = i + self.run_id_base

# self.logger.debug(
# f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, gold_reward: {gold_reward}"
# )

# === RULER scores as rewards ===
assert (
self.auxiliary_models is not None
), "Current implementation of RULER requires that auxiliary_models is not None."
judge_success, ruler_scores = self.get_ruler_scores(
responses=responses, judger=self.auxiliary_models[0]
)
for i, response in enumerate(responses):
response.reward = ruler_scores[i]
response.metrics.update({"judge_success": float(judge_success)})

return responses

def get_ruler_scores(
self, responses: List[Experience], judger: Any
) -> Tuple[bool, List[float]]:
"""Get RULER scores"""

num_responses = len(responses)

# Step 1: format prompt for judge
ruler_system_prompt = f"You are a fair judge. The user will provide a question and {num_responses} candidate solutions to it. Your task is to compare the solutions, see how well they resolve the question, and assign a score within the range [0, 1] for each solution."

question_prompt = (
f"Question: {self.task_desc}\n\n"
f"""Solution format requirement: first thinks about the reasoning process in the mind and then provides the final answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think>
<answer> answer here </answer>."""
)

solutions_prompt_parts = [
f"Candidate solution {i + 1}: {response.response_text}"
for i, response in enumerate(responses)
]
solutions_prompt = "\n\n".join(solutions_prompt_parts)

ruler_user_prompt = f"""
Below is a question and several candidate solutions.

{question_prompt}

{solutions_prompt}

Please assign a score within the range [0, 1] for each of them, reflecting how well they solve the question.
You may compare them against each other and think step by step before returning your final scores, but keep your reasoning process brief and concise when possible.

Conclude your response with a list of scores, in the following format: [score for solution 1, score for solution 2, ..., score for solution {num_responses}]
"""

# Step 2: invoke judger LLM
messages = [
{"role": "system", "content": ruler_system_prompt},
{"role": "user", "content": ruler_user_prompt},
]
completion = judger.chat.completions.create(
model=judger.model_path, messages=messages, stream=False
)
judger_response = completion.choices[0].message.content
logger.info(f"LLM judge response: {judger_response}")

# Step 3: extract scores from judger's response
idx1, idx2 = judger_response.rfind("["), judger_response.rfind("]")
if (idx1 == -1) or (idx2 == -1) or (idx1 > idx2):
logger.warning("Unable to extract a list from judger response, set scores to all zero.")
return False, [0.0 for _ in range(num_responses)]
        lst_as_str = judger_response[idx1 : (idx2 + 1)]
        try:
            scores = ast.literal_eval(lst_as_str)
            if len(scores) != num_responses:
                logger.warning(
                    "Judger returned a different number of scores than responses, set scores to all zero."
                )
                return False, [0.0 for _ in range(num_responses)]
            scores = [max(0.0, min(1.0, score)) for score in scores]  # clip to range [0, 1]
            return True, scores
        except Exception:
            logger.warning("Unable to parse the list in judger response, set scores to all zero.")
            return False, [0.0 for _ in range(num_responses)]