3 changes: 1 addition & 2 deletions docs/sphinx_doc/source/tutorial/example_react.md
@@ -146,8 +146,7 @@ The `algorithm` section configures the training algorithm for the agent application

```yaml
algorithm:
algorithm_type: grpo
advantage_fn: step_wise_grpo # The key for multi-step training. This strategy tells Trinity to create independent training samples for each step in the agent's execution path. The `grpo` algorithm then uses these samples to update the model.
algorithm_type: multi_step_grpo # Specify multi-step GRPO training algorithm
```
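
For reference, selecting `multi_step_grpo` pulls in the defaults registered by the new `MultiStepGRPOAlgorithm` in `trinity/algorithm/algorithm.py` (added later in this PR), including `advantage_fn: step_wise_grpo` and a PPO-style policy loss. The snippet below is only an illustrative sketch of that default-filling behaviour; the `resolve_algorithm_config` helper is hypothetical and not part of Trinity's API.

```python
from typing import Dict

# Default values taken from MultiStepGRPOAlgorithm.default_config() in this PR.
MULTI_STEP_GRPO_DEFAULTS: Dict[str, object] = {
    "repeat_times": 8,
    "advantage_fn": "step_wise_grpo",  # broadcasts the last step's advantage to earlier steps
    "sample_strategy": "default",
    "policy_loss_fn": "ppo",
    "kl_penalty_fn": "none",
    "kl_loss_fn": "k2",
    "entropy_loss_fn": "default",
}


def resolve_algorithm_config(user_cfg: Dict[str, object]) -> Dict[str, object]:
    """Fill unspecified algorithm fields with the registered defaults (sketch only)."""
    resolved = dict(MULTI_STEP_GRPO_DEFAULTS)
    resolved.update(user_cfg)  # values set explicitly in the YAML win over the defaults
    return resolved


# The YAML above only sets `algorithm_type`, so every other field falls back to these defaults.
print(resolve_algorithm_config({}))
```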

#### Dynamic Synchronization Configuration
5 changes: 2 additions & 3 deletions docs/sphinx_doc/source/tutorial/example_step_wise.md
@@ -77,7 +77,7 @@ and include it in the init file `trinity/common/workflows/__init__.py`

In general multi-step scenarios, each run may generate a varying number of experiences. To accommodate this case, we provide some flexible designs.

- `algorithm.advantage_fn = step_wise_grpo`: This function allows you to compute the advantages for the collected experiences before they are added to the buffer. For this example, we use `step_wise_grpo`, which broadcasts advantages from the last step to previous steps.
- `algorithm.algorithm_type = multi_step_grpo`: This algorithm allows each run to contain multiple steps and generate multiple experiences for training; it broadcasts the advantage value of the last-step experience to the previous experiences (see the sketch after this list).

- `buffer.train_batch_size`: The number of experiences to be sampled from the buffer for training, which can be different from the number of generated experiences in each explore step.
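
To make the broadcasting behaviour concrete, here is a small self-contained sketch. The `Experience` container and the group-normalization details are simplified assumptions for illustration; this is not Trinity's `step_wise_grpo` implementation.

```python
# Illustrative sketch of "broadcast the last step's advantage to earlier steps".
# The Experience dataclass and the grouping logic are simplified stand-ins,
# not Trinity's actual data structures.
from dataclasses import dataclass
from statistics import mean, pstdev
from typing import List


@dataclass
class Experience:
    run_id: int          # which rollout (run) this step belongs to
    step: int            # step index inside the run
    reward: float = 0.0  # only the last step of a run carries the task reward
    advantage: float = 0.0


def broadcast_last_step_advantage(runs: List[List[Experience]]) -> None:
    """GRPO-style: normalize last-step rewards across the runs of one task,
    then copy each run's last-step advantage to all of its earlier steps."""
    last_rewards = [run[-1].reward for run in runs]
    mu, sigma = mean(last_rewards), pstdev(last_rewards) or 1.0
    for run in runs:
        adv = (run[-1].reward - mu) / sigma
        for exp in run:  # every step in the run shares the same advantage
            exp.advantage = adv


# Two runs of the same task with different lengths and different final rewards.
runs = [
    [Experience(0, 0), Experience(0, 1, reward=1.0)],
    [Experience(1, 0), Experience(1, 1), Experience(1, 2, reward=0.0)],
]
broadcast_last_step_advantage(runs)
print([(e.run_id, e.step, round(e.advantage, 2)) for run in runs for e in run])
```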

@@ -93,9 +93,8 @@ project: "ALFWORLD"
name: "Step_Wise_Alfworld"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: grpo
algorithm_type: multi_step_grpo
repeat_times: 16
advantage_fn: step_wise_grpo
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
max_response_tokens: 16384
3 changes: 1 addition & 2 deletions docs/sphinx_doc/source_zh/tutorial/example_react.md
@@ -153,8 +153,7 @@ explorer:

```yaml
algorithm:
algorithm_type: grpo
advantage_fn: step_wise_grpo # The key to multi-step training. This strategy tells Trinity to create an independent training sample for each step in the agent's execution path; the `grpo` algorithm then uses these samples to update the model.
algorithm_type: multi_step_grpo # Specify the multi-step GRPO training algorithm
```

#### Dynamic Synchronization Configuration
9 changes: 4 additions & 5 deletions docs/sphinx_doc/source_zh/tutorial/example_step_wise.md
@@ -6,11 +6,11 @@

Next, we will use ALFWorld as an example to illustrate the general multi-step workflow. For hands-on practice, you can jump directly to the [code implementation](#example-multi-step-alfworld).

## Building a General Step-Wise Workflow
## Building a General Multi-Step Workflow

### Basic Concepts

In Trinity, we provide two general step-wise workflow types: `StepWiseRewardWorkflow` and `RewardPropagationWorkflow`. These workflows define the basic structure of each step and return a list of `experiences` for every run. The difference between them: `StepWiseRewardWorkflow` computes a reward for each step individually, while `RewardPropagationWorkflow` computes the reward after all steps have finished and propagates it back to the previous steps. See `trinity/common/workflows/step_wise_workflow.py` for more details.
In Trinity, we provide two general multi-step workflow types: `StepWiseRewardWorkflow` and `RewardPropagationWorkflow`. These workflows define the basic structure of each step and return a list of `experiences` for every run. The difference between them: `StepWiseRewardWorkflow` computes a reward for each step individually, while `RewardPropagationWorkflow` computes the reward after all steps have finished and propagates it back to the previous steps. See `trinity/common/workflows/step_wise_workflow.py` for more details.

To build a new workflow, you mainly need to define the per-step interaction logic in `step()` and the reward function in `reward()`. For example, the core code of the ALFWorld workflow looks like this:

@@ -76,7 +76,7 @@ class StepWiseAlfworldWorkflow(RewardPropagationWorkflow):

In general multi-step scenarios, each run may generate a different number of experiences. To accommodate this, we provide some flexible designs.

- `algorithm.advantage_fn = step_wise_grpo`: This function allows you to compute the advantage for each experience before it is added to the buffer. In this example, we use `step_wise_grpo`, which broadcasts the advantage of the last step to the experiences of the preceding steps.
- `algorithm.algorithm_type = multi_step_grpo`: This algorithm allows each run to contain multiple steps and generate multiple experiences for training, and it broadcasts the advantage value of the last-step experience to the previous experiences.

- `buffer.train_batch_size`: The number of experiences sampled from the buffer for training, which may differ from the number of experiences generated in each explore step.
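
To relate this back to the two base classes described earlier, the sketch below shows roughly what a `RewardPropagationWorkflow`-style subclass could look like. The base-class signatures and the `Experience` type are assumptions for illustration and may differ from the real classes in `trinity/common/workflows/step_wise_workflow.py`.

```python
# Hypothetical sketch: the base class, method signatures, and Experience type are
# assumed for illustration and may not match Trinity's actual API.
from typing import List


class Experience:
    """Simplified stand-in for the data recorded at one step."""

    def __init__(self, observation: str, action: str, reward: float = 0.0):
        self.observation = observation
        self.action = action
        self.reward = reward


class RewardPropagationWorkflow:
    """Stand-in for the real base class: reward() is called once after the
    last step, and its value is propagated back to every recorded step."""

    def step(self, step_num: int) -> bool:
        raise NotImplementedError

    def reward(self, experiences: List[Experience]) -> float:
        raise NotImplementedError


class MyEnvWorkflow(RewardPropagationWorkflow):
    def __init__(self, max_steps: int = 10):
        self.max_steps = max_steps
        self.experiences: List[Experience] = []

    def step(self, step_num: int) -> bool:
        obs = f"observation at step {step_num}"   # would come from the environment
        action = f"action for step {step_num}"    # would come from the model
        self.experiences.append(Experience(obs, action))
        return step_num + 1 < self.max_steps      # returning False ends the run

    def reward(self, experiences: List[Experience]) -> float:
        return 1.0 if experiences else 0.0        # e.g. a task-success signal
```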

@@ -91,9 +91,8 @@ project: "ALFWORLD"
name: "Step_Wise_Alfworld"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: grpo
algorithm_type: multi_step_grpo
repeat_times: 16
advantage_fn: step_wise_grpo
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
max_response_tokens: 16384
3 changes: 1 addition & 2 deletions examples/agentscope_react/gsm8k.yaml
@@ -2,9 +2,8 @@ project: AgentScope-ReAct
name: GSM8K-Qwen3-8B
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: grpo
algorithm_type: multi_step_grpo
repeat_times: 8
advantage_fn: step_wise_grpo
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-8B}
max_response_tokens: 16384
@@ -2,9 +2,8 @@ project: "Trinity-RFT-dapo-reactv2"
name: "Qwen3-8B-dapo-reactv2"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: grpo
algorithm_type: multi_step_grpo
repeat_times: 8
advantage_fn: step_wise_grpo
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-8B}
max_response_tokens: 16384
@@ -2,9 +2,8 @@ project: "Trinity-RFT-gsm8k-reactv2"
name: "Qwen3-4B-gsm8k-reactv2"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: grpo
algorithm_type: multi_step_grpo
repeat_times: 8
advantage_fn: step_wise_grpo
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B}
max_response_tokens: 16384
@@ -2,9 +2,8 @@ project: "Trinity-RFT-dapo-react"
name: "Qwen3-4B-dapo-react"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: grpo
algorithm_type: multi_step_grpo
repeat_times: 8
advantage_fn: step_wise_grpo
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507}
max_response_tokens: 16384
@@ -2,9 +2,8 @@ project: "Trinity_Multi_Step"
name: WebQA_Search_Example
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: grpo
algorithm_type: multi_step_grpo
repeat_times: 8
advantage_fn: step_wise_grpo
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
max_response_tokens: 4096
3 changes: 1 addition & 2 deletions examples/grpo_alfworld_general_multi_step/alfworld.yaml
@@ -2,9 +2,8 @@ project: "ALFWORLD"
name: "Step_Wise_Alfworld"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: grpo
algorithm_type: multi_step_grpo
repeat_times: 16
advantage_fn: step_wise_grpo
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
max_response_tokens: 16384
3 changes: 1 addition & 2 deletions examples/grpo_email_search/email_search.yaml
@@ -2,9 +2,8 @@ project: "Trinity_Multi_Step"
name: "Email_Example"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: grpo
algorithm_type: multi_step_grpo
repeat_times: 8
advantage_fn: grpo
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507}
max_response_tokens: 4096
41 changes: 32 additions & 9 deletions trinity/algorithm/algorithm.py
@@ -82,7 +82,7 @@ class PPOAlgorithm(AlgorithmType):
def default_config(cls) -> Dict:
return {
"repeat_times": 1,
"sample_strategy": "warmup",
"sample_strategy": "default",
"policy_loss_fn": "ppo",
"advantage_fn": "ppo",
"kl_penalty_fn": "none",
@@ -106,7 +106,7 @@ def default_config(cls) -> Dict:
return {
"repeat_times": 2,
"advantage_fn": "grpo",
"sample_strategy": "warmup",
"sample_strategy": "default",
"policy_loss_fn": "ppo",
"kl_penalty_fn": "none",
"kl_loss_fn": "k2",
@@ -129,7 +129,7 @@ def default_config(cls) -> Dict:
return {
"repeat_times": 2,
"advantage_fn": "opmd",
"sample_strategy": "warmup",
"sample_strategy": "default",
"policy_loss_fn": "opmd",
"kl_penalty_fn": "none",
"kl_loss_fn": "k2",
@@ -151,7 +151,7 @@ class AsymREAlgorithm(AlgorithmType):
def default_config(cls) -> Dict:
return {
"repeat_times": 2,
"sample_strategy": "warmup",
"sample_strategy": "default",
"policy_loss_fn": "opmd",
"advantage_fn": "asymre",
"kl_penalty_fn": "none",
@@ -173,7 +173,7 @@ class DPOAlgorithm(AlgorithmType):
@classmethod
def default_config(cls) -> Dict:
return {
"sample_strategy": "warmup",
"sample_strategy": "default",
"policy_loss_fn": "dpo",
"kl_loss_fn": "k2",
"entropy_loss_fn": "default",
@@ -219,7 +219,7 @@ def default_config(cls) -> Dict:
return {
"repeat_times": 2,
"advantage_fn": "reinforce", # or simply use grpo
"sample_strategy": "warmup",
"sample_strategy": "default",
"policy_loss_fn": "topr",
"kl_penalty_fn": "none",
"kl_loss_fn": "k2",
@@ -242,7 +242,7 @@ def default_config(cls) -> Dict:
return {
"repeat_times": 2,
"advantage_fn": "grpo",
"sample_strategy": "warmup",
"sample_strategy": "default",
"policy_loss_fn": "cispo",
"kl_penalty_fn": "none",
"kl_loss_fn": "k2",
@@ -331,7 +331,7 @@ class sPPOAlgorithm(AlgorithmType):
def default_config(cls) -> Dict:
return {
"repeat_times": 2,
"sample_strategy": "warmup",
"sample_strategy": "default",
"policy_loss_fn": "sppo",
"advantage_fn": "opmd",
"kl_penalty_fn": "none",
@@ -354,10 +354,33 @@ class RECAlgorithm(AlgorithmType):
def default_config(cls) -> Dict:
return {
"repeat_times": 2,
"sample_strategy": "warmup",
"sample_strategy": "default",
"policy_loss_fn": "rec",
"advantage_fn": "rec",
"kl_penalty_fn": "none",
"kl_loss_fn": "none",
"entropy_loss_fn": "none",
}


@ALGORITHM_TYPE.register_module("multi_step_grpo")
class MultiStepGRPOAlgorithm(AlgorithmType):
"""Multi-Step GRPO Algorithm."""

use_critic: bool = False
use_reference: bool = True
compute_advantage_in_trainer: bool = False
can_balance_batch: bool = True
schema: str = "experience"

@classmethod
def default_config(cls) -> Dict:
return {
"repeat_times": 8,
"advantage_fn": "step_wise_grpo",
"sample_strategy": "default",
"policy_loss_fn": "ppo",
"kl_penalty_fn": "none",
"kl_loss_fn": "k2",
"entropy_loss_fn": "default",
}
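
The registration above follows the same decorator pattern as the other algorithm types in this file. As a hedged illustration only (not part of this PR), a project-specific variant could be registered the same way; the import path and anything beyond what is visible in this diff are assumptions.

```python
# Illustrative only: registers a hypothetical variant with a larger default
# repeat_times, reusing the pattern shown above. Assumes ALGORITHM_TYPE and
# AlgorithmType are importable from this module.
from typing import Dict

from trinity.algorithm.algorithm import ALGORITHM_TYPE, AlgorithmType


@ALGORITHM_TYPE.register_module("multi_step_grpo_x16")
class MultiStepGRPOx16Algorithm(AlgorithmType):
    """Multi-Step GRPO variant with more rollouts per task (hypothetical)."""

    use_critic: bool = False
    use_reference: bool = True
    compute_advantage_in_trainer: bool = False
    can_balance_batch: bool = True
    schema: str = "experience"

    @classmethod
    def default_config(cls) -> Dict:
        return {
            "repeat_times": 16,                # more rollouts per task than the default 8
            "advantage_fn": "step_wise_grpo",  # same last-step advantage broadcasting
            "sample_strategy": "default",
            "policy_loss_fn": "ppo",
            "kl_penalty_fn": "none",
            "kl_loss_fn": "k2",
            "entropy_loss_fn": "default",
        }
```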
4 changes: 1 addition & 3 deletions trinity/common/verl_config.py
@@ -5,7 +5,7 @@

from omegaconf import OmegaConf

from trinity.common.config import BufferConfig, Config, SynchronizerConfig
from trinity.common.config import Config, SynchronizerConfig
from trinity.common.constants import EXPLORER_NAME
from trinity.utils.log import get_logger

@@ -338,7 +338,6 @@ class veRLConfig:
custom_reward_function: CustomRewardFunction = field(default_factory=CustomRewardFunction)
algorithm: Algorithm = field(default_factory=Algorithm)
trainer: Trainer = field(default_factory=Trainer)
buffer: BufferConfig = field(default_factory=BufferConfig)
synchronizer: Optional[SynchronizerConfig] = None
enable_preview: bool = True

@@ -394,7 +393,6 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
else:
self.trainer.resume_mode = "auto"

self.buffer = config.buffer
self.data.train_batch_size = (
config.buffer.train_batch_size
) # kept to pass RayPPOTrainer._validate_config