diff --git a/docs/sphinx_doc/source/tutorial/example_react.md b/docs/sphinx_doc/source/tutorial/example_react.md
index d43ff65d6e..4264696ab9 100644
--- a/docs/sphinx_doc/source/tutorial/example_react.md
+++ b/docs/sphinx_doc/source/tutorial/example_react.md
@@ -146,8 +146,7 @@ The `algorithm` section configures the training algorithm for the agent applicat
 ```yaml
 algorithm:
-  algorithm_type: grpo
-  advantage_fn: step_wise_grpo # The key for multi-step training. This strategy tells Trinity to create independent training samples for each step in the agent's execution path. The `grpo` algorithm then uses these samples to update the model.
+  algorithm_type: multi_step_grpo # Specify the multi-step GRPO training algorithm
 ```
 #### Dynamic Synchronization Configuration
diff --git a/docs/sphinx_doc/source/tutorial/example_step_wise.md b/docs/sphinx_doc/source/tutorial/example_step_wise.md
index 75900fac37..c408369693 100644
--- a/docs/sphinx_doc/source/tutorial/example_step_wise.md
+++ b/docs/sphinx_doc/source/tutorial/example_step_wise.md
@@ -77,7 +77,7 @@ and include it in the init file `trinity/common/workflows/__init__.py`
 In general multi-step scenarios, each run may generate a varying number of experiences. To accommodate this case, we provide some flexible designs.
-- `algorithm.advantage_fn = step_wise_grpo`: This function allows you compute the advantages for the collected experience before adding to the buffer. For this example, we use `step_wise_grpo` which broadcasts advantages from the last step to previous steps.
+- `algorithm.algorithm_type = multi_step_grpo`: This algorithm allows each run to contain multiple steps and generate multiple experiences for training; it broadcasts the advantage value of the last-step experience to all previous experiences.
 - `buffer.train_batch_size`: The number of experiences to be sampled from the buffer for training, which can be different from the number of generated experiences in each explore step.
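Taken together, the two settings called out in the hunk above combine as follows in a config file. This is a minimal sketch with illustrative values: `algorithm_type` and the broadcast behavior come from this diff, while the `buffer` number is a placeholder rather than a value taken from any file in the repository.

```yaml
# Minimal sketch (illustrative values) of the new multi-step configuration.
algorithm:
  algorithm_type: multi_step_grpo  # replaces algorithm_type: grpo + advantage_fn: step_wise_grpo
  repeat_times: 16                 # number of runs per task, as in the ALFWorld example below
buffer:
  train_batch_size: 96             # placeholder; may differ from the number of experiences produced per explore step
```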
@@ -93,9 +93,8 @@ project: "ALFWORLD"
 name: "Step_Wise_Alfworld"
 checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
 algorithm:
-  algorithm_type: grpo
+  algorithm_type: multi_step_grpo
   repeat_times: 16
-  advantage_fn: step_wise_grpo
 model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
   max_response_tokens: 16384
diff --git a/docs/sphinx_doc/source_zh/tutorial/example_react.md b/docs/sphinx_doc/source_zh/tutorial/example_react.md
index 69d76d4784..ab6b6a1d73 100644
--- a/docs/sphinx_doc/source_zh/tutorial/example_react.md
+++ b/docs/sphinx_doc/source_zh/tutorial/example_react.md
@@ -153,8 +153,7 @@ explorer:
 ```yaml
 algorithm:
-  algorithm_type: grpo
-  advantage_fn: step_wise_grpo # 多步训练的关键,该策略告诉 Trinity 为智能体执行路径中的每一步创建独立的训练样本。`grpo` 算法随后使用这些样本来更新模型。
+  algorithm_type: multi_step_grpo # 指定多步 GRPO 训练算法
 ```
 #### 动态同步配置
diff --git a/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md b/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md
index 625aec69f3..76e1a3fcb7 100644
--- a/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md
+++ b/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md
@@ -6,11 +6,11 @@
 接下来我们将以 ALFWorld 为例说明通用多步工作流。如需动手实践,可直接跳转至 [代码实现](#example-multi-step-alfworld)。
-## 构建通用的逐 步工作流
+## 构建通用的多步工作流
 ### 基本概念
-在 Trinity 中,我们提供了两种通用的逐 步工作流类型:`StepWiseRewardWorkflow` 和 `RewardPropagationWorkflow`。这些工作流设定了每一步的基本结构,并在每次运行时返回一个 `experiences` 列表。它们的区别在于:`StepWiseRewardWorkflow` 为每一步单独计算奖励,而 `RewardPropagationWorkflow` 在所有步骤结束后计算奖励,并将奖励反向传播到之前的步骤。更多细节请参见 `trinity/common/workflows/step_wise_workflow.py`。
+在 Trinity 中,我们提供了两种通用的多步工作流类型:`StepWiseRewardWorkflow` 和 `RewardPropagationWorkflow`。这些工作流设定了每一步的基本结构,并在每次运行时返回一个 `experiences` 列表。它们的区别在于:`StepWiseRewardWorkflow` 为每一步单独计算奖励,而 `RewardPropagationWorkflow` 在所有步骤结束后计算奖励,并将奖励反向传播到之前的步骤。更多细节请参见 `trinity/common/workflows/step_wise_workflow.py`。
 要构建一个新的工作流,你主要需要定义 `step()` 中的每一步交互逻辑,以及 `reward()` 中的奖励函数。例如,ALFWorld 工作流的核心代码如下所示:
@@ -76,7 +76,7 @@ class StepWiseAlfworldWorkflow(RewardPropagationWorkflow):
 在通用多步场景中,每次运行可能会生成不同数量的 experience。为了适应这种情况,我们提供了一些灵活的设计。
-- `algorithm.advantage_fn = step_wise_grpo`:该函数允许你在将 experience 加入 buffer 前计算其 Advantage。在此示例中,我们使用 `step_wise_grpo`,它会将最后一步的 Advantage 广播到前面各步 experience。
+- `algorithm.algorithm_type = multi_step_grpo`:该算法允许每次运行包含多个步骤并生成多条 experience 数据用于训练,并将最后一步 experience 的 advantages 值广播到之前的 experience 中。
 - `buffer.train_batch_size`:从 buffer 中采样用于训练的 experience 数量,可以与每次探索生成的 experience 数量不同。
@@ -91,9 +91,8 @@ project: "ALFWORLD"
 name: "Step_Wise_Alfworld"
 checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
 algorithm:
-  algorithm_type: grpo
+  algorithm_type: multi_step_grpo
   repeat_times: 16
-  advantage_fn: step_wise_grpo
 model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
   max_response_tokens: 16384
diff --git a/examples/agentscope_react/gsm8k.yaml b/examples/agentscope_react/gsm8k.yaml
index b4f9a79952..e864a49b93 100644
--- a/examples/agentscope_react/gsm8k.yaml
+++ b/examples/agentscope_react/gsm8k.yaml
@@ -2,9 +2,8 @@ project: AgentScope-ReAct
 name: GSM8K-Qwen3-8B
 checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
 algorithm:
-  algorithm_type: grpo
+  algorithm_type: multi_step_grpo
   repeat_times: 8
-  advantage_fn: step_wise_grpo
 model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-8B}
   max_response_tokens: 16384
diff --git a/examples/agentscope_tool_react/agentscopev0_tool_react_dapo.yaml b/examples/agentscope_tool_react/agentscopev0_tool_react_dapo.yaml
index 4c0d8c1b9d..5e353ee95d 100644
--- a/examples/agentscope_tool_react/agentscopev0_tool_react_dapo.yaml
+++ b/examples/agentscope_tool_react/agentscopev0_tool_react_dapo.yaml
@@ -2,9 +2,8 @@ project: "Trinity-RFT-dapo-reactv2"
 name: "Qwen3-8B-dapo-reactv2"
 checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
 algorithm:
-  algorithm_type: grpo
+  algorithm_type: multi_step_grpo
   repeat_times: 8
-  advantage_fn: step_wise_grpo
 model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-8B}
   max_response_tokens: 16384
diff --git a/examples/agentscope_tool_react/agentscopev0_tool_react_gsm8k.yaml b/examples/agentscope_tool_react/agentscopev0_tool_react_gsm8k.yaml
index 57e37274f9..30cd13240d 100644
--- a/examples/agentscope_tool_react/agentscopev0_tool_react_gsm8k.yaml
+++ b/examples/agentscope_tool_react/agentscopev0_tool_react_gsm8k.yaml
@@ -2,9 +2,8 @@ project: "Trinity-RFT-gsm8k-reactv2"
 name: "Qwen3-4B-gsm8k-reactv2"
 checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
 algorithm:
-  algorithm_type: grpo
+  algorithm_type: multi_step_grpo
   repeat_times: 8
-  advantage_fn: step_wise_grpo
 model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B}
   max_response_tokens: 16384
diff --git a/examples/agentscope_tool_react/agentscopev1_tool_react_dapo.yaml b/examples/agentscope_tool_react/agentscopev1_tool_react_dapo.yaml
index 2b10707c4b..b45e7b85a1 100644
--- a/examples/agentscope_tool_react/agentscopev1_tool_react_dapo.yaml
+++ b/examples/agentscope_tool_react/agentscopev1_tool_react_dapo.yaml
@@ -2,9 +2,8 @@ project: "Trinity-RFT-dapo-react"
 name: "Qwen3-4B-dapo-react"
 checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
 algorithm:
-  algorithm_type: grpo
+  algorithm_type: multi_step_grpo
   repeat_times: 8
-  advantage_fn: step_wise_grpo
 model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507}
   max_response_tokens: 16384
diff --git a/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml b/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml
index 7d4835791b..b056181877 100644
--- a/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml
+++ b/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml
@@ -2,9 +2,8 @@ project: "Trinity_Multi_Step"
 name: WebQA_Search_Example
 checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
 algorithm:
-  algorithm_type: grpo
+  algorithm_type: multi_step_grpo
   repeat_times: 8
-  advantage_fn: step_wise_grpo
 model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
   max_response_tokens: 4096
diff --git a/examples/grpo_alfworld_general_multi_step/alfworld.yaml b/examples/grpo_alfworld_general_multi_step/alfworld.yaml
index f93f028420..4e17779b6e 100644
--- a/examples/grpo_alfworld_general_multi_step/alfworld.yaml
+++ b/examples/grpo_alfworld_general_multi_step/alfworld.yaml
@@ -2,9 +2,8 @@ project: "ALFWORLD"
 name: "Step_Wise_Alfworld"
 checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
 algorithm:
-  algorithm_type: grpo
+  algorithm_type: multi_step_grpo
   repeat_times: 16
-  advantage_fn: step_wise_grpo
 model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
   max_response_tokens: 16384
diff --git a/examples/grpo_email_search/email_search.yaml b/examples/grpo_email_search/email_search.yaml
index 67f0cb098e..0cc159244c 100644
--- a/examples/grpo_email_search/email_search.yaml
+++ b/examples/grpo_email_search/email_search.yaml
@@ -2,9 +2,8 @@ project: "Trinity_Multi_Step"
 name: "Email_Example"
 checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
 algorithm:
-  algorithm_type: grpo
+  algorithm_type: multi_step_grpo
   repeat_times: 8
-  advantage_fn: grpo
 model:
   model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507}
   max_response_tokens: 4096
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index f9ae14e8c9..4384da7b8e 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -82,7 +82,7 @@ class PPOAlgorithm(AlgorithmType):
     def default_config(cls) -> Dict:
         return {
             "repeat_times": 1,
-            "sample_strategy": "warmup",
+            "sample_strategy": "default",
             "policy_loss_fn": "ppo",
             "advantage_fn": "ppo",
             "kl_penalty_fn": "none",
@@ -106,7 +106,7 @@ def default_config(cls) -> Dict:
         return {
             "repeat_times": 2,
             "advantage_fn": "grpo",
-            "sample_strategy": "warmup",
+            "sample_strategy": "default",
             "policy_loss_fn": "ppo",
             "kl_penalty_fn": "none",
             "kl_loss_fn": "k2",
@@ -129,7 +129,7 @@ def default_config(cls) -> Dict:
         return {
             "repeat_times": 2,
             "advantage_fn": "opmd",
-            "sample_strategy": "warmup",
+            "sample_strategy": "default",
             "policy_loss_fn": "opmd",
             "kl_penalty_fn": "none",
             "kl_loss_fn": "k2",
@@ -151,7 +151,7 @@ class AsymREAlgorithm(AlgorithmType):
     def default_config(cls) -> Dict:
         return {
             "repeat_times": 2,
-            "sample_strategy": "warmup",
+            "sample_strategy": "default",
             "policy_loss_fn": "opmd",
             "advantage_fn": "asymre",
             "kl_penalty_fn": "none",
@@ -173,7 +173,7 @@ class DPOAlgorithm(AlgorithmType):
     @classmethod
     def default_config(cls) -> Dict:
         return {
-            "sample_strategy": "warmup",
+            "sample_strategy": "default",
             "policy_loss_fn": "dpo",
             "kl_loss_fn": "k2",
             "entropy_loss_fn": "default",
@@ -219,7 +219,7 @@ def default_config(cls) -> Dict:
         return {
             "repeat_times": 2,
             "advantage_fn": "reinforce",  # or simply use grpo
-            "sample_strategy": "warmup",
+            "sample_strategy": "default",
             "policy_loss_fn": "topr",
             "kl_penalty_fn": "none",
             "kl_loss_fn": "k2",
@@ -242,7 +242,7 @@ def default_config(cls) -> Dict:
         return {
             "repeat_times": 2,
             "advantage_fn": "grpo",
-            "sample_strategy": "warmup",
+            "sample_strategy": "default",
             "policy_loss_fn": "cispo",
             "kl_penalty_fn": "none",
             "kl_loss_fn": "k2",
@@ -331,7 +331,7 @@ class sPPOAlgorithm(AlgorithmType):
     def default_config(cls) -> Dict:
         return {
             "repeat_times": 2,
-            "sample_strategy": "warmup",
+            "sample_strategy": "default",
             "policy_loss_fn": "sppo",
             "advantage_fn": "opmd",
             "kl_penalty_fn": "none",
@@ -354,10 +354,33 @@ class RECAlgorithm(AlgorithmType):
     def default_config(cls) -> Dict:
         return {
             "repeat_times": 2,
-            "sample_strategy": "warmup",
+            "sample_strategy": "default",
             "policy_loss_fn": "rec",
             "advantage_fn": "rec",
             "kl_penalty_fn": "none",
             "kl_loss_fn": "none",
             "entropy_loss_fn": "none",
         }
+
+
+@ALGORITHM_TYPE.register_module("multi_step_grpo")
+class MultiStepGRPOAlgorithm(AlgorithmType):
+    """Multi-Step GRPO Algorithm."""
+
+    use_critic: bool = False
+    use_reference: bool = True
+    compute_advantage_in_trainer: bool = False
+    can_balance_batch: bool = True
+    schema: str = "experience"
+
+    @classmethod
+    def default_config(cls) -> Dict:
+        return {
+            "repeat_times": 8,
+            "advantage_fn": "step_wise_grpo",
+            "sample_strategy": "default",
+            "policy_loss_fn": "ppo",
+            "kl_penalty_fn": "none",
+            "kl_loss_fn": "k2",
+            "entropy_loss_fn": "default",
+        }
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index eb360266b4..9d7031d083 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -5,7 +5,7 @@ from omegaconf import OmegaConf
-from trinity.common.config import BufferConfig, Config, SynchronizerConfig
+from trinity.common.config import Config, SynchronizerConfig
 from trinity.common.constants import EXPLORER_NAME
 from trinity.utils.log import get_logger
@@ -338,7 +338,6 @@ class veRLConfig:
     custom_reward_function: CustomRewardFunction = field(default_factory=CustomRewardFunction)
     algorithm: Algorithm = field(default_factory=Algorithm)
     trainer: Trainer = field(default_factory=Trainer)
-    buffer: BufferConfig = field(default_factory=BufferConfig)
     synchronizer: Optional[SynchronizerConfig] = None
     enable_preview: bool = True
@@ -394,7 +393,6 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         else:
             self.trainer.resume_mode = "auto"
-        self.buffer = config.buffer
         self.data.train_batch_size = (
             config.buffer.train_batch_size
         )  # kept to pass RayPPOTrainer._validate_config
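In short, configs that previously combined `algorithm_type: grpo` with `advantage_fn: step_wise_grpo` now only set `algorithm_type: multi_step_grpo`; the registered `MultiStepGRPOAlgorithm` keeps `step_wise_grpo` as its default advantage function and computes advantages outside the trainer. Assuming the keys of `default_config()` map one-to-one onto the YAML `algorithm` section (as the example configs in this diff suggest), spelling the defaults out explicitly would look roughly as follows; this is a sketch for reference, not a required configuration.

```yaml
# Rough YAML equivalent of MultiStepGRPOAlgorithm.default_config() (assumed key mapping).
algorithm:
  algorithm_type: multi_step_grpo
  repeat_times: 8              # default; the ALFWorld configs above override this to 16
  advantage_fn: step_wise_grpo # broadcasts the last step's advantage to earlier steps
  sample_strategy: default
  policy_loss_fn: ppo
  kl_penalty_fn: none
  kl_loss_fn: k2
  entropy_loss_fn: default
```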