Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions docs/sphinx_doc/source/tutorial/develop_workflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,108 @@ class ExampleWorkflow(Workflow):
2. When calling `chat.completions.create`, the `model` field can be obtained via `openai_client.models.list().data[0].id` or `openai_client.model_path`.
3. For more complex workflow examples using the OpenAI API, refer to [ReAct Agent Training](./example_react.md).
```

#### LLM-as-a-judge Support

LLM-as-a-judge is a common reward calculation method, especially suitable for open-ended tasks (such as programming, writing, etc.). In these scenarios, the Workflow needs to leverage an additional LLM to evaluate the answer quality and compute the reward signal.

To support this, Trinity-RFT provides an Auxiliary Models mechanism. Auxiliary models are a set of models not involved in training; the Workflow can use these models to assist with tasks, such as acting as a judge to calculate rewards.

You can specify one or more auxiliary models in the configuration file via the `explorer.auxiliary_models` field. For example:

```yaml
explorer:
auxiliary_models:
- model_path: Qwen/Qwen2.5-32B-Instruct
engine_num: 1
tensor_parallel_size: 2
enable_thinking: false
max_prompt_tokens: 12288
max_response_tokens: 12288
max_model_len: 16384
- model_path: Qwen/Qwen3-8B
engine_num: 1
tensor_parallel_size: 2
enable_thinking: false
max_prompt_tokens: 12288
max_response_tokens: 12288
max_model_len: 16384
```

Note that each auxiliary model will independently occupy `tensor_parallel_size * engine_num` GPUs. Please configure according to your hardware resources. After enabling auxiliary models, the number of GPUs available to the Trainer is the total GPU count minus those occupied by all auxiliary models and the inference model being trained (`rollout_model`).

The auxiliary models specified in the configuration file will automatically activate the OpenAI API and pass the corresponding `openai.OpenAI` instances to the `auxiliary_models` parameter of the `Workflow` initialization method. For example:

```python
class MyWorkflow(Workflow):
def __init__(
self,
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.judge_model = self.auxiliary_models[0] # Use the first auxiliary model as the judge

def run(self) -> List[Experience]:
response = self.do_something()
reward_response = self.judge_model.chat.completions.create(
model=self.judge_model.model_path,
messages=[
{
"role": "system",
"content": "You are a judge. You need to give a score from 0 to 1 based on the quality of the answer.",
},
{
"role": "user",
"content": f"Question:\n{self.task.raw_task['question']}\nAnswer:\n{response.response_text}\nPlease give a score from 0 to 1.",
},
],
temperature=0.0,
max_tokens=10,
)
# Parse the reward score
reward = float(reward_response.choices[0].message.content.strip())
return [
Experience(
tokens=response.tokens,
prompt_length=response.prompt_length,
reward=reward,
logprobs=response.logprobs,
)
]
```


#### Debug Mode

During Workflow development, repeatedly launching the full training process for testing is time-consuming and inefficient. To address this, Trinity-RFT provides a Debug Mode for developers. This mode leverages a pre-launched inference model to quickly run specified workflows and obtain results, avoiding repeated model loading and initialization delays, and significantly improving development efficiency. The process is illustrated below:

```{mermaid}
flowchart LR
A[Start Inference Model] --> B[Debug Workflow]
B --> B
```

To start the inference model, use the following command:

```bash
trinity debug --config <config_file_path> --module inference_model
```

Here, `<config_file_path>` is the path to a YAML configuration file, which should follow the same format as the one used by the `trinity run` command. The `explorer.rollout_model` and `explorer.auxiliary_models` fields in the config will be loaded to initialize the inference model.

Once started, the model will keep running and wait for debug instructions; it will not exit automatically. You can then run the following command in another terminal to debug your workflow:

```bash
trinity debug --config <config_file_path> --module workflow --output_file <output_file_path> --plugin_dir <plugin_dir>
```

- `<config_file_path>`: Path to the YAML configuration file, usually the same as used for starting the inference model.
- `<output_file_path>`: Path to save the performance profiling results. Debug Mode uses [viztracer](https://github.com/gaogaotiantian/viztracer) to profile the workflow execution and saves the results as an HTML file for easy viewing in a browser.
- `<plugin_dir>` (optional): Path to the plugin directory. If your workflow or reward function modules are not built into Trinity-RFT, you can specify this parameter to load custom modules.

During debugging, the `buffer.explorer_input.taskset` field in the config will be loaded to initialize the workflow's required task dataset and instance. Note that Debug Mode only reads the first sample in the dataset for testing. After running the above command, the workflow's return value will be automatically formatted and printed in the terminal for easy inspection.

When debugging is complete, you can terminate the inference model by pressing `Ctrl+C` in its terminal.
104 changes: 104 additions & 0 deletions docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,107 @@ class ExampleWorkflow(Workflow):
2. 调用 `chat.completions.create` 时,其中的 `model` 字段可通过 `openai_client.models.list().data[0].id` 或 `openai_client.model_path` 获取。
3. 更复杂的使用 OpenAI API 的工作流实例可参考 [ReAct Agent 训练](./example_react.md)。
```

#### LLM-as-a-judge 支持

LLM-as-a-judge 是一种常见的奖励计算方法,尤其适用于开放式任务(如编程、写作等)。在这类场景下,Workflow 需要借助额外的 LLM 来评估答案质量并计算奖励信号(reward)。

为此,Trinity-RFT 提供了 Auxiliary Models(辅助模型)机制。辅助模型是一组未参与训练的模型,Workflow 可利用这些模型辅助完成任务,例如作为评判者(judge)计算奖励。

你可以在配置文件中通过 `explorer.auxiliary_models` 字段指定一个或多个辅助模型。例如:

```yaml
explorer:
auxiliary_models:
- model_path: Qwen/Qwen2.5-32B-Instruct
engine_num: 1
tensor_parallel_size: 2
enable_thinking: false
max_prompt_tokens: 12288
max_response_tokens: 12288
max_model_len: 16384
- model_path: Qwen/Qwen3-8B
engine_num: 1
tensor_parallel_size: 2
enable_thinking: false
max_prompt_tokens: 12288
max_response_tokens: 12288
max_model_len: 16384
```

请注意,每个辅助模型会独立占用 `tensor_parallel_size * engine_num` 个 GPU,请根据硬件资源合理配置。在启用辅助模型后,Trainer 可用的 GPU 数量为总 GPU 数量减去所有辅助模型及被训练的推理模型(`rollout_model`)所占用的 GPU 数量。

配置文件中指定的辅助模型会自动激活 OpenAI API,并将对应的 `openai.OpenAI` 实例传递给 `Workflow` 初始化方法的 `auxiliary_models` 参数。例如:

```python
class MyWorkflow(Workflow):
def __init__(
self,
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.judge_model = self.auxiliary_models[0] # 使用第一个辅助模型作为评判者

def run(self) -> List[Experience]:
response = self.do_something()
reward_response = self.judge_model.chat.completions.create(
model=self.judge_model.model_path,
messages=[
{
"role": "system",
"content": "You are a judge. You need to give a score from 0 to 1 based on the quality of the answer.",
},
{
"role": "user",
"content": f"Question:\n{self.task.raw_task['question']}\nAnswer:\n{response.response_text}\nPlease give a score from 0 to 1.",
},
],
temperature=0.0,
max_tokens=10,
)
# 解析奖励分数
reward = float(reward_response.choices[0].message.content.strip())
return [
Experience(
tokens=response.tokens,
prompt_length=response.prompt_length,
reward=reward,
logprobs=response.logprobs,
)
]
```

#### 调试模式(Debug Mode)

在 Workflow 开发过程中,频繁启动完整训练流程进行测试既耗时又低效。为此,Trinity-RFT 为开发者提供了调试模式。该模式通过预先启动推理模型,能够快速运行指定的工作流并获取结果,避免因模型加载和初始化带来的重复等待,大幅提升开发效率。流程如下:

```{mermaid}
flowchart LR
A[启动推理模型] --> B[调试 Workflow]
B --> B
```

启动推理模型的命令如下:

```bash
trinity debug --config <config_file_path> --module inference_model
```

其中,`config_file_path` 为 YAML 格式的配置文件路径,格式与 `trinity run` 命令所用配置文件一致。配置文件中的 `explorer.rollout_model` 和 `explorer.auxiliary_models` 字段会被加载,用于初始化推理模型。

模型启动后会持续运行并等待调试指令,不会自动退出。此时,你可在另一个终端执行如下命令进行 Workflow 调试:

```bash
trinity debug --config <config_file_path> --module workflow --output_file <output_file_path> --plugin_dir <plugin_dir>
```

- `config_file_path`:YAML 配置文件路径,通常与启动推理模型时使用的配置文件相同。
- `output_file_path`:性能分析结果输出路径。调试模式会使用 [viztracer](https://github.com/gaogaotiantian/viztracer) 对 Workflow 运行过程进行性能分析,并将结果保存为 HTML 文件,便于在浏览器中查看。
- `plugin_dir`(可选):插件目录路径。如果你的 Workflow 或奖励函数等模块未内置于 Trinity-RFT,可通过该参数加载自定义模块。

调试过程中,配置文件中的 `buffer.explorer_input.taskset` 字段会被加载,用于初始化 Workflow 所需的任务数据集和实例。需注意,调试模式仅会读取数据集中的第一条数据进行测试。运行上述命令后,Workflow 的返回值会自动格式化并打印在终端,方便查看运行结果。

调试完成后,可在推理模型终端输入 `Ctrl+C` 以终止模型运行。
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ dev = [
"pytest-json-ctrf",
"parameterized",
"matplotlib",
"viztracer",
]
megatron = [
"megatron-core[mlm]==0.13.1",
Expand Down
74 changes: 63 additions & 11 deletions tests/cli/launcher_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import multiprocessing
import os
import shutil
import sys
import time
import unittest
from unittest import mock
from unittest.mock import MagicMock
Expand All @@ -18,6 +20,12 @@
StageConfig,
TrainerInput,
)
from trinity.common.constants import (
LOG_DIR_ENV_VAR,
LOG_LEVEL_ENV_VAR,
LOG_NODE_IP_ENV_VAR,
)
from trinity.common.models import get_debug_inference_model


class TestLauncherMain(unittest.TestCase):
Expand Down Expand Up @@ -108,9 +116,9 @@ def test_main_run_in_dlc(self, mock_init, mock_load, mock_both, mock_setup, mock
runtime_env={
"env_vars": {
launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
launcher.LOG_DIR_ENV_VAR: config.log.save_dir,
launcher.LOG_LEVEL_ENV_VAR: config.log.level,
launcher.LOG_NODE_IP_ENV_VAR: "1",
LOG_DIR_ENV_VAR: config.log.save_dir,
LOG_LEVEL_ENV_VAR: config.log.level,
LOG_NODE_IP_ENV_VAR: "1",
}
},
)
Expand Down Expand Up @@ -202,14 +210,14 @@ def test_multi_stage_run(
runtime_env={
"env_vars": {
launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
launcher.LOG_DIR_ENV_VAR: os.path.join(
LOG_DIR_ENV_VAR: os.path.join(
config.checkpoint_root_dir,
config.project,
f"{config.name}/sft_warmup",
"log",
),
launcher.LOG_LEVEL_ENV_VAR: config.log.level,
launcher.LOG_NODE_IP_ENV_VAR: "0",
LOG_LEVEL_ENV_VAR: config.log.level,
LOG_NODE_IP_ENV_VAR: "0",
}
},
),
Expand All @@ -220,14 +228,14 @@ def test_multi_stage_run(
runtime_env={
"env_vars": {
launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
launcher.LOG_DIR_ENV_VAR: os.path.join(
LOG_DIR_ENV_VAR: os.path.join(
config.checkpoint_root_dir,
config.project,
f"{config.name}/grpo",
"log",
),
launcher.LOG_LEVEL_ENV_VAR: config.log.level,
launcher.LOG_NODE_IP_ENV_VAR: "0",
LOG_LEVEL_ENV_VAR: config.log.level,
LOG_NODE_IP_ENV_VAR: "0",
}
},
),
Expand All @@ -241,6 +249,50 @@ def test_multi_stage_run(
"/path/to/hf/checkpoint",
)

@mock.patch("trinity.cli.launcher.load_config")
def test_debug_mode(self, mock_load):
process = multiprocessing.Process(target=debug_inference_model_process)
process.start()
time.sleep(15) # wait for the model to be created
for _ in range(10):
try:
get_debug_inference_model(self.config)
break
except Exception:
time.sleep(3)
output_file = os.path.join(self.config.checkpoint_job_dir, "debug.html")
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
mock_load.return_value = self.config
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="debug",
config="dummy.yaml",
module="workflow",
output_file=output_file,
plugin_dir="",
),
):
launcher.main()
process.join(timeout=10)
process.terminate()
self.assertTrue(os.path.exists(output_file))

if __name__ == "__main__":
unittest.main()

def debug_inference_model_process():
config = get_template_config()
config.checkpoint_root_dir = get_checkpoint_path()
config.model.model_path = get_model_path()
config.check_and_update()
with mock.patch("trinity.cli.launcher.load_config", return_value=config):
with mock.patch(
"argparse.ArgumentParser.parse_args",
return_value=mock.Mock(
command="debug",
config="dummy.yaml",
module="inference_model",
plugin_dir=None,
output_file=None,
),
):
launcher.main()
Loading