From b8c262a9f0bd824f81961b8478d3c7c238840bf7 Mon Sep 17 00:00:00 2001
From: hiyuchang
Date: Thu, 10 Jul 2025 13:58:13 +0800
Subject: [PATCH 01/10] add dapo workflow and reward

---
 examples/dapo_math/dapo.yaml              | 73 +++++++++++++++++++++++
 examples/dapo_math/train_dapo.yaml        | 48 +++++++++++++++
 trinity/common/rewards/__init__.py        |  2 +
 trinity/common/rewards/dapo_reward.py     | 69 +++++++++++++++++++++
 trinity/common/workflows/__init__.py      |  2 +
 trinity/common/workflows/dapo_workflow.py | 65 ++++++++++++++++++++
 6 files changed, 259 insertions(+)
 create mode 100644 examples/dapo_math/dapo.yaml
 create mode 100644 examples/dapo_math/train_dapo.yaml
 create mode 100644 trinity/common/rewards/dapo_reward.py
 create mode 100644 trinity/common/workflows/dapo_workflow.py

diff --git a/examples/dapo_math/dapo.yaml b/examples/dapo_math/dapo.yaml
new file mode 100644
index 0000000000..dee8e5461f
--- /dev/null
+++ b/examples/dapo_math/dapo.yaml
@@ -0,0 +1,73 @@
+project: Trinity-RFT-example
+name: dapo
+checkpoint_root_dir: /PATH/TO/CHECKPOINT/
+model:
+  model_path: /PATH/TO/MODEL/
+algorithm:
+  algorithm_type: grpo
+  repeat_times: 16
+  policy_loss_fn_args:
+    clip_range_low: 0.2
+    clip_range_high: 0.28
+cluster:
+  node_num: 1
+  gpu_per_node: 8
+buffer:
+  total_epochs: 1
+  batch_size: 32
+  max_retry_times: 3
+  max_retry_interval: 1
+  explorer_input:
+    taskset:
+      name: dapo-math
+      storage_type: file
+      path: /PATH/TO/Processed-DAPO-Math-17k/
+      format:
+        prompt_key: 'prompt'
+        response_key: 'ground_truth'
+      rollout_args:
+        temperature: 1.0
+        logprobs: 0
+      workflow_args:
+        enable_overlong_penalty: true
+        penalty_factor: 1.0
+        max_response_length: 20480
+        cache_length: 4096
+    eval_tasksets:
+    - name: AIME2024
+      storage_type: file
+      path: /PATH/TO/AIME2024/
+      split: 'test'
+      format:
+        prompt_key: 'question'
+        response_key: 'answer'
+      rollout_args:
+        n: 32
+        temperature: 1.0
+        top_p: 0.7
+    default_workflow_type: 'math_dapo_workflow'
+  trainer_input:
+    experience_buffer:
+      name: math_buffer
+      storage_type: queue
+explorer:
+  eval_interval: 10
+  runner_num: 32
+  rollout_model:
+    engine_type: vllm_async
+    engine_num: 4
+    tensor_parallel_size: 1
+    enable_prefix_caching: false
+    enforce_eager: true
+    dtype: bfloat16
+    max_prompt_tokens: 1024
+    max_response_tokens: 20480
+    seed: 42
+synchronizer:
+  sync_method: 'nccl'
+  sync_interval: 16
+  sync_timeout: 1200
+trainer:
+  trainer_type: 'verl'
+  trainer_config_path: 'examples/dapo_math/train_dapo.yaml'
+  save_interval: 100
diff --git a/examples/dapo_math/train_dapo.yaml b/examples/dapo_math/train_dapo.yaml
new file mode 100644
index 0000000000..32831b03fe
--- /dev/null
+++ b/examples/dapo_math/train_dapo.yaml
@@ -0,0 +1,48 @@
+actor_rollout_ref:
+  hybrid_engine: True
+  model:
+    external_lib: null
+    override_config: { }
+    enable_gradient_checkpointing: True
+    use_remove_padding: True # False
+  actor:
+    strategy: fsdp # This is for backward-compatibility
+    ppo_micro_batch_size_per_gpu: 4
+    use_dynamic_bsz: True # False
+    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
+    grad_clip: 1.0
+    ppo_epochs: 1
+    shuffle: False
+    ulysses_sequence_parallel_size: 1 # sp size
+    optim:
+      lr: 1e-6
+      lr_warmup_steps: 20 # the total steps will be injected during runtime
+      # min_lr_ratio: null # only useful for warmup with cosine
+      warmup_style: constant # select from constant/cosine
+      total_training_steps: -1 # must be override by program
+    fsdp_config:
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+      param_offload: False
+      optimizer_offload: False
+      fsdp_size: -1
+  ref:
+    fsdp_config:
+      param_offload: False
+      wrap_policy:
+        # transformer_layer_cls_to_wrap: None
+        min_num_params: 0
+    log_prob_micro_batch_size_per_gpu: 16
+    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
+
+trainer:
+  balance_batch: True
+  # auto: find the last ckpt to resume. If can't find, start from scratch
+  resume_mode: auto # or auto or resume_path if
+  default_hdfs_dir: null
+  remove_previous_ckpt_in_save: False
+  del_local_ckpt_after_load: False
+  val_before_train: False
diff --git a/trinity/common/rewards/__init__.py b/trinity/common/rewards/__init__.py
index f20abc2748..08379befda 100644
--- a/trinity/common/rewards/__init__.py
+++ b/trinity/common/rewards/__init__.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 """Reward functions for RFT"""
 
+from .dapo_reward import MathDAPORewardFn
 from .reward_fn import REWARD_FUNCTIONS, AccuracyReward, FormatReward, RewardFn
 
 __all__ = [
@@ -8,4 +9,5 @@
     "REWARD_FUNCTIONS",
     "AccuracyReward",
     "FormatReward",
+    "MathDAPORewardFn",
 ]
diff --git a/trinity/common/rewards/dapo_reward.py b/trinity/common/rewards/dapo_reward.py
new file mode 100644
index 0000000000..3ebe87b92f
--- /dev/null
+++ b/trinity/common/rewards/dapo_reward.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+from typing import Optional, Union
+
+import torch
+
+from trinity.common.rewards import REWARD_FUNCTIONS, RewardFn
+from trinity.utils.eval_utils import compute_score
+from trinity.utils.log import get_logger
+
+logger = get_logger(__name__)
+
+
+@REWARD_FUNCTIONS.register_module("math_dapo_reward")
+class MathDAPORewardFn(RewardFn):
+    """A reward function that follows the definition in DAPO for math task."""
+
+    def __init__(
+        self,
+        enable_overlong_penalty: Optional[bool] = None,
+        penalty_factor: Optional[float] = None,
+        max_response_length: Optional[int] = None,
+        cache_length: Optional[int] = None,
+    ) -> None:
+        self.enable_overlong_penalty = enable_overlong_penalty
+        self.penalty_factor = penalty_factor
+        self.max_response_length = max_response_length
+        self.cache_length = cache_length
+
+    def __call__(  # type: ignore
+        self,
+        response: str,
+        response_token: torch.Tensor,
+        truth: Optional[str] = None,
+        return_dict: Optional[bool] = False,
+    ) -> Union[float, dict]:
+        accuracy_score = compute_score(response, truth)
+
+        format_score = 0.0
+
+        if self.enable_overlong_penalty:
+            format_score = self.compute_overlong_penalty(response_token)
+
+        if return_dict:
+            return {
+                "accuracy": accuracy_score,
+                "format_score": format_score,
+            }
+
+        return accuracy_score + format_score
+
+    def compute_overlong_penalty(self, response_token):
+        assert (
+            self.max_response_length is not None
+            and self.cache_length is not None
+            and self.penalty_factor is not None
+        ), "When enable_overlong_penalty = true, max_response_length, penalty_factor, cache_length must be set"
+        assert (
+            self.max_response_length > self.cache_length
+        ), "max_response_length must be greater than cache_length"
+
+        response_len = len(response_token)
+        excepted_len = self.max_response_length - self.cache_length
+
+        if response_len < excepted_len:
+            return 0.0
+        elif response_len > self.max_response_length:
+            return -1
+        else:
+            return (excepted_len - response_len) / self.cache_length * self.penalty_factor
diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py
index 9d54f108d0..c5054eb276 100644
--- a/trinity/common/workflows/__init__.py
+++ b/trinity/common/workflows/__init__.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 """Workflow module"""
 from .customized_math_workflows import MathBoxedWorkflow
+from .dapo_workflow import MathDAPOWorkflow
 from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
 from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
 from .envs.webshop.webshop_workflow import WebShopWorkflow
@@ -16,4 +17,5 @@
     "AlfworldWorkflow",
     "SciWorldWorkflow",
     "MathBoxedWorkflow",
+    "MathDAPOWorkflow",
 ]
diff --git a/trinity/common/workflows/dapo_workflow.py b/trinity/common/workflows/dapo_workflow.py
new file mode 100644
index 0000000000..3a64af2a2c
--- /dev/null
+++ b/trinity/common/workflows/dapo_workflow.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+"""We include the DAPO math workflows in this file."""
+
+from dataclasses import asdict
+from typing import List
+
+from trinity.common.experience import Experience
+from trinity.common.rewards.dapo_reward import MathDAPORewardFn
+from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
+from trinity.utils.log import get_logger
+
+logger = get_logger(__name__)
+
+
+@WORKFLOWS.register_module("math_dapo_workflow")
+class MathDAPOWorkflow(SimpleWorkflow):
+    """A workflow for math tasks as introduced in DAPO."""
+
+    def reset(self, task: Task):
+        self.format_args = task.format_args
+
+        self.raw_task = task.raw_task
+        self.task_desc = task.task_desc
+        self.truth = task.truth
+
+        rollout_args = asdict(task.rollout_args)
+        self.rollout_args = rollout_args
+        self.is_eval = task.is_eval
+
+        self.workflow_args = task.workflow_args
+
+        self.reward_fn = MathDAPORewardFn(
+            enable_overlong_penalty=self.workflow_args.get("enable_overlong_penalty", None),
+            penalty_factor=self.workflow_args.get("penalty_factor", None),
+            max_response_length=self.workflow_args.get("max_response_length", None),
+            cache_length=self.workflow_args.get("cache_length", None),
+        )
+
+    def run(self) -> List[Experience]:
+        messages = self.format_messages()
+
+        logger.debug("start chat")
+        responses = self.model.chat(messages, **self.rollout_args)
+
+        for response in responses:
+            reward = self.reward_fn(  # type: ignore # TODO: fix type
+                response=response.response_text,  # type: ignore [arg-type]
+                truth=self.truth,
+                return_dict=self.is_eval,
+                response_token=response.tokens[response.prompt_length :],
+            )
+            logger.debug(
+                f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
+            )
+            if isinstance(reward, dict):
+                if response.metrics is None:
+                    response.metrics = {}
+                response.metrics.update(reward)
+                reward = sum(reward.values())
+            response.reward = reward
+        return responses
+
+    def format_messages(self):
+        messages = [{"role": "user", "content": self.task_desc}]
+        return messages

From b9076b7a0a98e6c174029cc3ef5fd0ae9d03be7f Mon Sep 17 00:00:00 2001
From: hiyuchang
Date: Thu, 10 Jul 2025 15:48:55 +0800
Subject: [PATCH 02/10] fix bug

---
 trinity/common/rewards/__init__.py    | 9 +++++++--
 trinity/common/rewards/dapo_reward.py | 2 +-
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/trinity/common/rewards/__init__.py b/trinity/common/rewards/__init__.py
index 08379befda..c7d870ddcf 100644
--- a/trinity/common/rewards/__init__.py
+++ b/trinity/common/rewards/__init__.py
@@ -1,8 +1,13 @@
 # -*- coding: utf-8 -*-
 """Reward functions for RFT"""
 
-from .dapo_reward import MathDAPORewardFn
-from .reward_fn import REWARD_FUNCTIONS, AccuracyReward, FormatReward, RewardFn
+from trinity.common.rewards.dapo_reward import MathDAPORewardFn
+from trinity.common.rewards.reward_fn import (
+    REWARD_FUNCTIONS,
+    AccuracyReward,
+    FormatReward,
+    RewardFn,
+)
 
 __all__ = [
diff --git a/trinity/common/rewards/dapo_reward.py b/trinity/common/rewards/dapo_reward.py
index 3ebe87b92f..19b422cf90 100644
--- a/trinity/common/rewards/dapo_reward.py
+++ b/trinity/common/rewards/dapo_reward.py
@@ -3,7 +3,7 @@
 
 import torch
 
-from trinity.common.rewards import REWARD_FUNCTIONS, RewardFn
+from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
 from trinity.utils.eval_utils import compute_score
 from trinity.utils.log import get_logger
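The overlong penalty introduced by MathDAPORewardFn above is a piecewise function of the
generated length. Below is a minimal standalone sketch of that logic, not part of the
patch series; the function name is made up, and the default values mirror the
workflow_args in dapo.yaml (max_response_length=20480, cache_length=4096, penalty_factor=1.0):

    # Illustrative sketch only; mirrors MathDAPORewardFn.compute_overlong_penalty.
    def overlong_penalty(
        response_len: int,
        max_response_length: int = 20480,
        cache_length: int = 4096,
        penalty_factor: float = 1.0,
    ) -> float:
        expected_len = max_response_length - cache_length  # 16384 with the defaults above
        if response_len < expected_len:
            return 0.0  # within budget: no penalty
        elif response_len > max_response_length:
            return -penalty_factor  # past the hard limit (the series later changes -1 to this)
        else:
            # linear ramp from 0 down to -penalty_factor across the cache window
            return (expected_len - response_len) / cache_length * penalty_factor

    # overlong_penalty(16000) == 0.0, overlong_penalty(18432) == -0.5, overlong_penalty(20500) == -1.0

Responses up to max_response_length - cache_length are unpenalized, the penalty then grows
linearly over the cache window, and anything past the hard limit receives the full penalty.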
From 889b36c595e9cd70ddd569b83c0134b021e6990b Mon Sep 17 00:00:00 2001
From: hiyuchang
Date: Mon, 14 Jul 2025 10:49:10 +0800
Subject: [PATCH 03/10] add base model mode for dapo_workflow

---
 examples/dapo_math/README.md                | 5 ++++
 examples/dapo_math/dapo.yaml                | 9 ++++--
 .../workflows/customized_math_workflows.py  | 15 +++-------
 trinity/common/workflows/dapo_workflow.py   | 30 ++++++++++---------
 trinity/common/workflows/workflow.py        | 11 +++++++
 5 files changed, 42 insertions(+), 28 deletions(-)
 create mode 100644 examples/dapo_math/README.md

diff --git a/examples/dapo_math/README.md b/examples/dapo_math/README.md
new file mode 100644
index 0000000000..28f8f3c625
--- /dev/null
+++ b/examples/dapo_math/README.md
@@ -0,0 +1,5 @@
+# DAPO on DAPO-MATH-17k dataset [WIP]
+
+This example shows the usage of DAPO on the [DAPO-MATH-17k](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed) dataset.
+
+The config files are located in [`dapo.yaml`](dapo.yaml) and [`train_dapo.yaml`](train_dapo.yaml).
diff --git a/examples/dapo_math/dapo.yaml b/examples/dapo_math/dapo.yaml
index dee8e5461f..d57a5b796b 100644
--- a/examples/dapo_math/dapo.yaml
+++ b/examples/dapo_math/dapo.yaml
@@ -21,14 +21,17 @@ buffer:
     taskset:
       name: dapo-math
      storage_type: file
-      path: /PATH/TO/Processed-DAPO-Math-17k/
+      path: open-r1/DAPO-Math-17k-Processed
+      subset: all
       format:
         prompt_key: 'prompt'
-        response_key: 'ground_truth'
+        response_key: 'solution'
+        system_prompt: 'Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.'
       rollout_args:
         temperature: 1.0
         logprobs: 0
       workflow_args:
+        use_base: true
         enable_overlong_penalty: true
         penalty_factor: 1.0
         max_response_length: 20480
@@ -45,7 +48,7 @@ buffer:
         n: 32
         temperature: 1.0
         top_p: 0.7
-    default_workflow_type: 'math_dapo_workflow'
+    default_workflow_type: 'math_boxed_workflow'
     trainer_input:
       experience_buffer:
         name: math_buffer
diff --git a/trinity/common/workflows/customized_math_workflows.py b/trinity/common/workflows/customized_math_workflows.py
index d71a5d2fb1..f8727869e9 100644
--- a/trinity/common/workflows/customized_math_workflows.py
+++ b/trinity/common/workflows/customized_math_workflows.py
@@ -48,19 +48,12 @@ def reset(self, task: Task):
         else:
             self.system_prompt = default_prompt
 
-        self.reward_fn = MathBoxedRewardFn()
-
-    def format_prompt(self):
-        prompt_text = ""
-        if self.system_prompt:
-            prompt_text += "System:" + self.system_prompt
-            prompt_text += "\nUser:\n" + self.task_desc + "\nAssistant:\n"
+        if task.reward_fn is None:
+            self.reward_fn = MathBoxedRewardFn()
         else:
-            prompt_text += "User:\n" + self.task_desc + "\nAssistant:\n"
-        return prompt_text
+            self.reward_fn = task.reward_fn
 
     def run(self) -> List[Experience]:
-        # TODO: Optimize the generate function
         if not self.use_base:
             messages = self.format_messages()
         else:
@@ -73,7 +66,7 @@ def run(self) -> List[Experience]:
             responses = self.model.generate([prompt_text], **self.rollout_args)
 
         for response in responses:
-            reward = MathBoxedRewardFn()(  # type: ignore [misc]
+            reward = self.reward_fn(  # type: ignore [misc]
                 response=response.response_text,  # type: ignore [arg-type]
                 truth=self.truth,
                 return_dict=self.is_eval,
diff --git a/trinity/common/workflows/dapo_workflow.py b/trinity/common/workflows/dapo_workflow.py
index 3a64af2a2c..410ffee393 100644
--- a/trinity/common/workflows/dapo_workflow.py
+++ b/trinity/common/workflows/dapo_workflow.py
@@ -28,6 +28,7 @@ def reset(self, task: Task):
         self.is_eval = task.is_eval
 
         self.workflow_args = task.workflow_args
+        self.use_base = self.workflow_args.get("use_base", False)
 
         self.reward_fn = MathDAPORewardFn(
             enable_overlong_penalty=self.workflow_args.get("enable_overlong_penalty", None),
@@ -37,29 +38,30 @@ def reset(self, task: Task):
         )
 
     def run(self) -> List[Experience]:
-        messages = self.format_messages()
+        if not self.use_base:
+            messages = self.format_messages()
+        else:
+            prompt_text = self.format_prompt()
 
         logger.debug("start chat")
-        responses = self.model.chat(messages, **self.rollout_args)
+        if not self.use_base:
+            responses = self.model.chat(messages, **self.rollout_args)
+        else:
+            responses = self.model.generate([prompt_text], **self.rollout_args)
 
         for response in responses:
-            reward = self.reward_fn(  # type: ignore # TODO: fix type
+            reward_dict = self.reward_fn(  # type: ignore
                 response=response.response_text,  # type: ignore [arg-type]
                 truth=self.truth,
-                return_dict=self.is_eval,
                 response_token=response.tokens[response.prompt_length :],
             )
+            if response.metrics is None:
+                response.metrics = {}
+            response.metrics.update(reward_dict)
+            reward = sum(reward_dict.values())
+            response.reward = reward
+
             logger.debug(
                 f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
             )
-            if isinstance(reward, dict):
-                if response.metrics is None:
-                    response.metrics = {}
-                response.metrics.update(reward)
-                reward = sum(reward.values())
-            response.reward = reward
         return responses
-
-    def format_messages(self):
-        messages = [{"role": "user", "content": self.task_desc}]
-        return messages
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 2bd0038435..03339f5ec3 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -194,6 +194,7 @@ def reset(self, task: Task):
         self.is_eval = task.is_eval
 
     def format_messages(self):
+        """Format messages for the instruct model."""
         messages = []
         if self.system_prompt:
             messages.append({"role": "system", "content": self.system_prompt})
@@ -202,6 +203,16 @@
             messages.append({"role": "assistant", "content": self.reply_prefix})
         return messages
 
+    def format_prompt(self):
+        """Format prompt for the base model."""
+        prompt_text = ""
+        if self.system_prompt:
+            prompt_text += "System:" + self.system_prompt
+            prompt_text += "\nUser:\n" + self.task_desc + "\nAssistant:\n"
+        else:
+            prompt_text += "User:\n" + self.task_desc + "\nAssistant:\n"
+        return prompt_text
+
     def run(self) -> List[Experience]:
         # TODO: Optimize the generate function
         messages = self.format_messages()
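The use_base path added in this patch sends a single plain-text prompt through
model.generate instead of chat messages. Here is a rough sketch of what format_prompt
assembles, assuming a placeholder system prompt and question (both strings below are made
up for illustration and are not from the dataset):

    # Illustrative reconstruction of the base-model prompt template added in this patch.
    system_prompt = "Solve the following math problem step by step."  # placeholder
    task_desc = "What is 2 + 2?"                                       # placeholder question

    prompt_text = ""
    if system_prompt:
        prompt_text += "System:" + system_prompt
        prompt_text += "\nUser:\n" + task_desc + "\nAssistant:\n"
    else:
        prompt_text += "User:\n" + task_desc + "\nAssistant:\n"

    print(prompt_text)
    # System:Solve the following math problem step by step.
    # User:
    # What is 2 + 2?
    # Assistant:

Note there is no separator after "System:" at this point in the series; patch 6 later
inserts a newline there so the system prompt starts on its own line.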
From 880c5b0249e470cc91a81bd3cf141d0bf1a7f017 Mon Sep 17 00:00:00 2001
From: hiyuchang
Date: Mon, 14 Jul 2025 16:53:42 +0800
Subject: [PATCH 04/10] fix dapo config

---
 examples/dapo_math/dapo.yaml              |  3 ++-
 examples/dapo_math/train_dapo.yaml        |  2 +-
 trinity/common/rewards/dapo_reward.py     | 12 ++++--------
 trinity/common/workflows/dapo_workflow.py |  2 ++
 4 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/examples/dapo_math/dapo.yaml b/examples/dapo_math/dapo.yaml
index d57a5b796b..aba504c347 100644
--- a/examples/dapo_math/dapo.yaml
+++ b/examples/dapo_math/dapo.yaml
@@ -44,11 +44,12 @@ buffer:
       format:
         prompt_key: 'question'
         response_key: 'answer'
+        system_prompt: 'Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.'
       rollout_args:
         n: 32
         temperature: 1.0
         top_p: 0.7
-    default_workflow_type: 'math_boxed_workflow'
+    default_workflow_type: 'math_dapo_workflow'
     trainer_input:
       experience_buffer:
         name: math_buffer
diff --git a/examples/dapo_math/train_dapo.yaml b/examples/dapo_math/train_dapo.yaml
index 32831b03fe..7bba8612ab 100644
--- a/examples/dapo_math/train_dapo.yaml
+++ b/examples/dapo_math/train_dapo.yaml
@@ -9,7 +9,7 @@ actor_rollout_ref:
     strategy: fsdp # This is for backward-compatibility
     ppo_micro_batch_size_per_gpu: 4
     use_dynamic_bsz: True # False
-    ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
+    ppo_max_token_len_per_gpu: 22000 # n * ${data.max_prompt_length} + ${data.max_response_length}
     grad_clip: 1.0
     ppo_epochs: 1
     shuffle: False
diff --git a/trinity/common/rewards/dapo_reward.py b/trinity/common/rewards/dapo_reward.py
index 19b422cf90..3b5b4657f7 100644
--- a/trinity/common/rewards/dapo_reward.py
+++ b/trinity/common/rewards/dapo_reward.py
@@ -31,7 +31,6 @@ def __call__(  # type: ignore
         response: str,
         response_token: torch.Tensor,
         truth: Optional[str] = None,
-        return_dict: Optional[bool] = False,
     ) -> Union[float, dict]:
         accuracy_score = compute_score(response, truth)
 
@@ -40,13 +39,10 @@ def __call__(  # type: ignore
         if self.enable_overlong_penalty:
             format_score = self.compute_overlong_penalty(response_token)
 
-        if return_dict:
-            return {
-                "accuracy": accuracy_score,
-                "format_score": format_score,
-            }
-
-        return accuracy_score + format_score
+        return {
+            "accuracy": accuracy_score,
+            "format_score": format_score,
+        }
 
     def compute_overlong_penalty(self, response_token):
         assert (
diff --git a/trinity/common/workflows/dapo_workflow.py b/trinity/common/workflows/dapo_workflow.py
index 410ffee393..b3cf5bdea5 100644
--- a/trinity/common/workflows/dapo_workflow.py
+++ b/trinity/common/workflows/dapo_workflow.py
@@ -18,6 +18,8 @@ class MathDAPOWorkflow(SimpleWorkflow):
 
     def reset(self, task: Task):
         self.format_args = task.format_args
+        self.system_prompt = task.format_args.system_prompt
+        self.reply_prefix = task.format_args.reply_prefix
 
         self.raw_task = task.raw_task
         self.task_desc = task.task_desc

From c314a11790d51f9260dd3d302d982af6457e3efa Mon Sep 17 00:00:00 2001
From: hiyuchang
Date: Mon, 14 Jul 2025 19:53:55 +0800
Subject: [PATCH 05/10] fix typo

---
 examples/dapo_math/dapo.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/dapo_math/dapo.yaml b/examples/dapo_math/dapo.yaml
index aba504c347..fb847bc9f9 100644
--- a/examples/dapo_math/dapo.yaml
+++ b/examples/dapo_math/dapo.yaml
@@ -22,7 +22,7 @@ buffer:
       name: dapo-math
       storage_type: file
       path: open-r1/DAPO-Math-17k-Processed
-      subset: all
+      subset_name: all
       format:
         prompt_key: 'prompt'
         response_key: 'solution'

From e1c76220b699f26895b4948f5c9b3e167b4db655 Mon Sep 17 00:00:00 2001
From: hiyuchang
Date: Tue, 15 Jul 2025 10:47:23 +0800
Subject: [PATCH 06/10] fix comments

---
 trinity/common/rewards/dapo_reward.py     | 3 ++-
 trinity/common/workflows/dapo_workflow.py | 6 +++++-
 trinity/common/workflows/workflow.py      | 2 +-
 3 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/trinity/common/rewards/dapo_reward.py b/trinity/common/rewards/dapo_reward.py
index 3b5b4657f7..8263bf4d59 100644
--- a/trinity/common/rewards/dapo_reward.py
+++ b/trinity/common/rewards/dapo_reward.py
@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+"""Reward Function with Overlong Reward Shaping described in DAPO (https://arxiv.org/pdf/2503.14476)"""
 from typing import Optional, Union
 
 import torch
@@ -60,6 +61,6 @@ def compute_overlong_penalty(self, response_token):
         if response_len < excepted_len:
             return 0.0
         elif response_len > self.max_response_length:
-            return -1
+            return -self.penalty_factor
         else:
             return (excepted_len - response_len) / self.cache_length * self.penalty_factor
diff --git a/trinity/common/workflows/dapo_workflow.py b/trinity/common/workflows/dapo_workflow.py
index b3cf5bdea5..abd4b5b29d 100644
--- a/trinity/common/workflows/dapo_workflow.py
+++ b/trinity/common/workflows/dapo_workflow.py
@@ -1,5 +1,9 @@
 # -*- coding: utf-8 -*-
-"""We include the DAPO math workflows in this file."""
+"""
+We include the DAPO math workflows in this file.
+This workflow adopts MathDAPORewardFn as the reward function.
+Ref: https://arxiv.org/pdf/2503.14476
+"""
 
 from dataclasses import asdict
 from typing import List
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 657aaf8721..d95146941f 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -210,7 +210,7 @@ def format_prompt(self):
         """Format prompt for the base model."""
         prompt_text = ""
         if self.system_prompt:
-            prompt_text += "System:" + self.system_prompt
+            prompt_text += "System:\n" + self.system_prompt
             prompt_text += "\nUser:\n" + self.task_desc + "\nAssistant:\n"
         else:
             prompt_text += "User:\n" + self.task_desc + "\nAssistant:\n"
From 7621887842aa170120193a9d22e342d1866a6361 Mon Sep 17 00:00:00 2001
From: hiyuchang
Date: Wed, 16 Jul 2025 13:57:40 +0800
Subject: [PATCH 07/10] mv dapo workflow to boxed workflo

---
 examples/dapo_math/dapo.yaml                |  4 +-
 trinity/common/rewards/__init__.py          |  1 +
 trinity/common/rewards/dapo_reward.py       |  1 +
 trinity/common/rewards/math_reward.py       |  3 +-
 trinity/common/workflows/__init__.py        |  2 -
 .../workflows/customized_math_workflows.py  | 26 +++++--
 trinity/common/workflows/dapo_workflow.py   | 73 -------------------
 trinity/common/workflows/workflow.py        | 10 ---
 8 files changed, 28 insertions(+), 92 deletions(-)
 delete mode 100644 trinity/common/workflows/dapo_workflow.py

diff --git a/examples/dapo_math/dapo.yaml b/examples/dapo_math/dapo.yaml
index fb847bc9f9..a13375500b 100644
--- a/examples/dapo_math/dapo.yaml
+++ b/examples/dapo_math/dapo.yaml
@@ -32,6 +32,7 @@ buffer:
         logprobs: 0
       workflow_args:
         use_base: true
+      reward_fn_args:
         enable_overlong_penalty: true
         penalty_factor: 1.0
         max_response_length: 20480
@@ -49,7 +50,8 @@ buffer:
         n: 32
         temperature: 1.0
         top_p: 0.7
-    default_workflow_type: 'math_dapo_workflow'
+    default_workflow_type: 'math_boxed_workflow'
+    default_reward_fn_type: 'math_dapo_reward'
     trainer_input:
       experience_buffer:
         name: math_buffer
diff --git a/trinity/common/rewards/__init__.py b/trinity/common/rewards/__init__.py
index a5f63a12d0..c723788908 100644
--- a/trinity/common/rewards/__init__.py
+++ b/trinity/common/rewards/__init__.py
@@ -6,6 +6,7 @@
 
 from .accuracy_reward import AccuracyReward
 from .countdown_reward import CountDownRewardFn
+from .dapo_reward import MathDAPORewardFn
 from .format_reward import FormatReward
 from .math_reward import MathBoxedRewardFn, MathRewardFn
diff --git a/trinity/common/rewards/dapo_reward.py b/trinity/common/rewards/dapo_reward.py
index 8263bf4d59..15f1f2bd67 100644
--- a/trinity/common/rewards/dapo_reward.py
+++ b/trinity/common/rewards/dapo_reward.py
@@ -32,6 +32,7 @@ def __call__(  # type: ignore
         response: str,
         response_token: torch.Tensor,
         truth: Optional[str] = None,
+        **kwargs,
     ) -> Union[float, dict]:
         accuracy_score = compute_score(response, truth)
diff --git a/trinity/common/rewards/math_reward.py b/trinity/common/rewards/math_reward.py
index 09a5cd7428..9e1c8abed6 100644
--- a/trinity/common/rewards/math_reward.py
+++ b/trinity/common/rewards/math_reward.py
@@ -49,16 +49,17 @@ class MathBoxedRewardFn(RewardFn):
 
     def __init__(
         self,
+        **kwargs,
     ) -> None:
         pass
 
     def __call__(  # type: ignore
         self,
         response: str,
-        prompt: Optional[str] = None,
         truth: Optional[str] = None,
         with_think: Optional[bool] = False,
         format_score_coef: Optional[float] = 0.1,
+        **kwargs,
     ) -> dict[str, float]:
         accuracy_score = compute_score(response, truth)
diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py
index 4906d50f5e..496996a05d 100644
--- a/trinity/common/workflows/__init__.py
+++ b/trinity/common/workflows/__init__.py
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 """Workflow module"""
 from .customized_math_workflows import MathBoxedWorkflow
-from .dapo_workflow import MathDAPOWorkflow
 from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
 from .envs.sciworld.sciworld_workflow import SciWorldWorkflow
 from .envs.webshop.webshop_workflow import WebShopWorkflow
@@ -19,5 +18,4 @@
     "SciWorldWorkflow",
     "MathBoxedWorkflow",
     "MathRMWorkflow",
-    "MathDAPOWorkflow",
 ]
diff --git a/trinity/common/workflows/customized_math_workflows.py b/trinity/common/workflows/customized_math_workflows.py
index 22760faf74..c2762ae43c 100644
--- a/trinity/common/workflows/customized_math_workflows.py
+++ b/trinity/common/workflows/customized_math_workflows.py
@@ -31,6 +31,7 @@ def reset(self, task: Task):
         self.is_eval = task.is_eval
 
         self.workflow_args = task.workflow_args
+        self.reward_fn_args = task.reward_fn_args
         self.use_base = self.workflow_args.get("use_base", False)
         self.with_think = self.workflow_args.get("with_think", False)
 
@@ -49,9 +50,18 @@ def reset(self, task: Task):
             self.system_prompt = default_prompt
 
         if task.reward_fn is None:
-            self.reward_fn = MathBoxedRewardFn()
+            self.reward_fn = MathBoxedRewardFn(**self.reward_fn_args)
         else:
-            self.reward_fn = task.reward_fn
+            self.reward_fn = task.reward_fn(**self.reward_fn_args)
+
+    def format_prompt(self):
+        prompt_text = ""
+        if self.system_prompt:
+            prompt_text += "System:" + self.system_prompt
+            prompt_text += "\nUser:\n" + self.task_desc + "\nAssistant:\n"
+        else:
+            prompt_text += "User:\n" + self.task_desc + "\nAssistant:\n"
+        return prompt_text
 
     def run(self) -> List[Experience]:
         if not self.use_base:
@@ -71,6 +81,7 @@ def run(self) -> List[Experience]:
                 truth=self.truth,
                 with_think=self.with_think,
                 format_score_coef=self.format_score_coef,
+                response_token=response.tokens[response.prompt_length :],
             )
 
             if response.metrics is None:
@@ -79,7 +90,12 @@ def run(self) -> List[Experience]:
             reward = sum(reward_dict.values())
             response.reward = reward
 
-            logger.debug(
-                f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
-            )
+            if not self.use_base:
+                logger.debug(
+                    f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
+                )
+            else:
+                logger.debug(
+                    f"self.task_desc: {self.task_desc}, prompt_text: {prompt_text}, response: {response.response_text}, reward: {reward}"
+                )
         return responses
diff --git a/trinity/common/workflows/dapo_workflow.py b/trinity/common/workflows/dapo_workflow.py
deleted file mode 100644
index abd4b5b29d..0000000000
--- a/trinity/common/workflows/dapo_workflow.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-We include the DAPO math workflows in this file.
-This workflow adopts MathDAPORewardFn as the reward function.
-Ref: https://arxiv.org/pdf/2503.14476
-"""
-
-from dataclasses import asdict
-from typing import List
-
-from trinity.common.experience import Experience
-from trinity.common.rewards.dapo_reward import MathDAPORewardFn
-from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
-from trinity.utils.log import get_logger
-
-logger = get_logger(__name__)
-
-
-@WORKFLOWS.register_module("math_dapo_workflow")
-class MathDAPOWorkflow(SimpleWorkflow):
-    """A workflow for math tasks as introduced in DAPO."""
-
-    def reset(self, task: Task):
-        self.format_args = task.format_args
-        self.system_prompt = task.format_args.system_prompt
-        self.reply_prefix = task.format_args.reply_prefix
-
-        self.raw_task = task.raw_task
-        self.task_desc = task.task_desc
-        self.truth = task.truth
-
-        rollout_args = asdict(task.rollout_args)
-        self.rollout_args = rollout_args
-        self.is_eval = task.is_eval
-
-        self.workflow_args = task.workflow_args
-        self.use_base = self.workflow_args.get("use_base", False)
-
-        self.reward_fn = MathDAPORewardFn(
-            enable_overlong_penalty=self.workflow_args.get("enable_overlong_penalty", None),
-            penalty_factor=self.workflow_args.get("penalty_factor", None),
-            max_response_length=self.workflow_args.get("max_response_length", None),
-            cache_length=self.workflow_args.get("cache_length", None),
-        )
-
-    def run(self) -> List[Experience]:
-        if not self.use_base:
-            messages = self.format_messages()
-        else:
-            prompt_text = self.format_prompt()
-
-        logger.debug("start chat")
-        if not self.use_base:
-            responses = self.model.chat(messages, **self.rollout_args)
-        else:
-            responses = self.model.generate([prompt_text], **self.rollout_args)
-
-        for response in responses:
-            reward_dict = self.reward_fn(  # type: ignore
-                response=response.response_text,  # type: ignore [arg-type]
-                truth=self.truth,
-                response_token=response.tokens[response.prompt_length :],
-            )
-            if response.metrics is None:
-                response.metrics = {}
-            response.metrics.update(reward_dict)
-            reward = sum(reward_dict.values())
-            response.reward = reward
-
-            logger.debug(
-                f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}"
-            )
-        return responses
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index d95146941f..e9549d9e2e 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -206,16 +206,6 @@ def format_messages(self):
             messages.append({"role": "assistant", "content": self.reply_prefix})
         return messages
 
-    def format_prompt(self):
-        """Format prompt for the base model."""
-        prompt_text = ""
-        if self.system_prompt:
-            prompt_text += "System:\n" + self.system_prompt
-            prompt_text += "\nUser:\n" + self.task_desc + "\nAssistant:\n"
-        else:
-            prompt_text += "User:\n" + self.task_desc + "\nAssistant:\n"
-        return prompt_text
-
     def run(self) -> List[Experience]:
         # TODO: Optimize the generate function
         messages = self.format_messages()
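After this patch the DAPO-specific pieces live entirely in the reward layer: the config
selects math_boxed_workflow together with math_dapo_reward, and MathBoxedWorkflow builds
the reward with task.reward_fn(**self.reward_fn_args), then sums the returned dict into
response.reward. A rough standalone sketch of that wiring, with a stub class standing in
for MathDAPORewardFn (illustrative only, not repository code):

    # Hypothetical stand-ins for the registry lookup and reward application.
    reward_fn_args = {  # mirrors the reward_fn_args block in dapo.yaml
        "enable_overlong_penalty": True,
        "penalty_factor": 1.0,
        "max_response_length": 20480,
        "cache_length": 4096,
    }

    class StubDAPOReward:  # stands in for MathDAPORewardFn
        def __init__(self, **kwargs):
            self.kwargs = kwargs

        def __call__(self, response, response_token, truth=None, **_):
            return {"accuracy": 1.0, "format_score": 0.0}  # fixed values for illustration

    reward_fn = StubDAPOReward(**reward_fn_args)   # i.e. task.reward_fn(**self.reward_fn_args)
    reward_dict = reward_fn(response="42", response_token=[1, 2, 3], truth="42")
    reward = sum(reward_dict.values())             # becomes response.reward in the workflow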
From ca04b7ae9a7010c8154ef4a20821827d84f5a7ed Mon Sep 17 00:00:00 2001
From: hiyuchang
Date: Wed, 16 Jul 2025 17:08:18 +0800
Subject: [PATCH 08/10] fix unittest

---
 tests/explorer/workflow_test.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py
index 20a8c5064c..2132a591a0 100644
--- a/tests/explorer/workflow_test.py
+++ b/tests/explorer/workflow_test.py
@@ -5,6 +5,8 @@
 from typing import Dict, Optional
 from unittest.mock import MagicMock
 
+from torch import Tensor
+
 from tests.tools import get_unittest_dataset_config
 from trinity.common.rewards import RMGalleryFn
 from trinity.common.workflows import (
@@ -23,6 +25,8 @@ class MockResponse:
     metrics: Optional[Dict[str, float]] = None
     info: Optional[Dict] = None
     unique_id: Optional[str] = "0"
+    tokens: Optional[Tensor] = Tensor([0, 0])
+    prompt_length: int = 1
 
 
 class DummyWorkflow(Workflow):

From 7d962139aac05cfd3e433d16f10168532be6da2c Mon Sep 17 00:00:00 2001
From: hiyuchang
Date: Wed, 16 Jul 2025 19:24:42 +0800
Subject: [PATCH 09/10] fix typo

---
 trinity/common/rewards/dapo_reward.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/trinity/common/rewards/dapo_reward.py b/trinity/common/rewards/dapo_reward.py
index 15f1f2bd67..703a8c2bc0 100644
--- a/trinity/common/rewards/dapo_reward.py
+++ b/trinity/common/rewards/dapo_reward.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 """Reward Function with Overlong Reward Shaping described in DAPO (https://arxiv.org/pdf/2503.14476)"""
-from typing import Optional, Union
+from typing import Dict, Optional
 
 import torch
 
@@ -33,7 +33,7 @@ def __call__(  # type: ignore
         response_token: torch.Tensor,
         truth: Optional[str] = None,
         **kwargs,
-    ) -> Union[float, dict]:
+    ) -> Dict[str, float]:
         accuracy_score = compute_score(response, truth)

From 3239c7379a29b7d89dc143bf7b63e36760eb259e Mon Sep 17 00:00:00 2001
From: hiyuchang
Date: Wed, 16 Jul 2025 19:27:06 +0800
Subject: [PATCH 10/10] fix typo

---
 trinity/common/rewards/dapo_reward.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/trinity/common/rewards/dapo_reward.py b/trinity/common/rewards/dapo_reward.py
index 703a8c2bc0..a527bf613a 100644
--- a/trinity/common/rewards/dapo_reward.py
+++ b/trinity/common/rewards/dapo_reward.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 """Reward Function with Overlong Reward Shaping described in DAPO (https://arxiv.org/pdf/2503.14476)"""
-from typing import Dict, Optional
+from typing import Optional
 
 import torch
 
@@ -33,7 +33,7 @@ def __call__(  # type: ignore
         response_token: torch.Tensor,
         truth: Optional[str] = None,
         **kwargs,
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         accuracy_score = compute_score(response, truth)
 
         format_score = 0.0