diff --git a/docs/sphinx_doc/source/tutorial/develop_workflow.md b/docs/sphinx_doc/source/tutorial/develop_workflow.md
index 816efe6997..b40c87801c 100644
--- a/docs/sphinx_doc/source/tutorial/develop_workflow.md
+++ b/docs/sphinx_doc/source/tutorial/develop_workflow.md
@@ -428,3 +428,108 @@ class ExampleWorkflow(Workflow):
 2. When calling `chat.completions.create`, the `model` field can be obtained via `openai_client.models.list().data[0].id` or `openai_client.model_path`.
 3. For more complex workflow examples using the OpenAI API, refer to [ReAct Agent Training](./example_react.md).
 ```
+
+#### LLM-as-a-judge Support
+
+LLM-as-a-judge is a common reward calculation method, especially suitable for open-ended tasks (such as programming, writing, etc.). In these scenarios, the Workflow needs to leverage an additional LLM to evaluate the answer quality and compute the reward signal.
+
+To support this, Trinity-RFT provides an Auxiliary Models mechanism. Auxiliary models are a set of models not involved in training; the Workflow can use these models to assist with tasks, such as acting as a judge to calculate rewards.
+
+You can specify one or more auxiliary models in the configuration file via the `explorer.auxiliary_models` field. For example:
+
+```yaml
+explorer:
+  auxiliary_models:
+    - model_path: Qwen/Qwen2.5-32B-Instruct
+      engine_num: 1
+      tensor_parallel_size: 2
+      enable_thinking: false
+      max_prompt_tokens: 12288
+      max_response_tokens: 12288
+      max_model_len: 16384
+    - model_path: Qwen/Qwen3-8B
+      engine_num: 1
+      tensor_parallel_size: 2
+      enable_thinking: false
+      max_prompt_tokens: 12288
+      max_response_tokens: 12288
+      max_model_len: 16384
+```
+
+Note that each auxiliary model will independently occupy `tensor_parallel_size * engine_num` GPUs. Please configure according to your hardware resources. After enabling auxiliary models, the number of GPUs available to the Trainer is the total GPU count minus those occupied by all auxiliary models and the inference model being trained (`rollout_model`).
+
+The auxiliary models specified in the configuration file will automatically activate the OpenAI API and pass the corresponding `openai.OpenAI` instances to the `auxiliary_models` parameter of the `Workflow` initialization method. For example:
+
+```python
+class MyWorkflow(Workflow):
+    def __init__(
+        self,
+        *,
+        task: Task,
+        model: ModelWrapper,
+        auxiliary_models: Optional[List[openai.OpenAI]] = None,
+    ):
+        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
+        self.judge_model = self.auxiliary_models[0]  # Use the first auxiliary model as the judge
+
+    def run(self) -> List[Experience]:
+        response = self.do_something()
+        reward_response = self.judge_model.chat.completions.create(
+            model=self.judge_model.model_path,
+            messages=[
+                {
+                    "role": "system",
+                    "content": "You are a judge. You need to give a score from 0 to 1 based on the quality of the answer.",
+                },
+                {
+                    "role": "user",
+                    "content": f"Question:\n{self.task.raw_task['question']}\nAnswer:\n{response.response_text}\nPlease give a score from 0 to 1.",
+                },
+            ],
+            temperature=0.0,
+            max_tokens=10,
+        )
+        # Parse the reward score
+        reward = float(reward_response.choices[0].message.content.strip())
+        return [
+            Experience(
+                tokens=response.tokens,
+                prompt_length=response.prompt_length,
+                reward=reward,
+                logprobs=response.logprobs,
+            )
+        ]
+```
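+
+The bare `float(...)` call above assumes the judge replies with nothing but a number; in practice the reply sometimes wraps the score in extra text, which would raise a `ValueError`. The helper below is only an illustrative defensive-parsing sketch (the regular expression, function name, and fallback value are choices of this tutorial, not part of the Trinity-RFT API):
+
+```python
+import re
+
+
+def parse_judge_score(text: str, default: float = 0.0) -> float:
+    """Extract the first number from the judge's reply and clamp it to [0, 1]."""
+    match = re.search(r"\d+(?:\.\d+)?", text)
+    if match is None:
+        return default  # no number found: fall back to a neutral/low reward
+    return min(max(float(match.group()), 0.0), 1.0)
+
+
+# Usage inside run():
+# reward = parse_judge_score(reward_response.choices[0].message.content)
+```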
+
+
+#### Debug Mode
+
+During Workflow development, repeatedly launching the full training process for testing is time-consuming and inefficient. To address this, Trinity-RFT provides a Debug Mode for developers. This mode leverages a pre-launched inference model to quickly run specified workflows and obtain results, avoiding the repeated waits caused by model loading and initialization, and significantly improving development efficiency. The process is illustrated below:
+
+```{mermaid}
+flowchart LR
+    A[Start Inference Model] --> B[Debug Workflow]
+    B --> B
+```
+
+To start the inference model, use the following command:
+
+```bash
+trinity debug --config <config_file_path> --module inference_model
+```
+
+Here, `<config_file_path>` is the path to a YAML configuration file, which should follow the same format as the one used by the `trinity run` command. The `explorer.rollout_model` and `explorer.auxiliary_models` fields in the config will be loaded to initialize the inference model.
+
+Once started, the model will keep running and wait for debug instructions; it will not exit automatically. You can then run the following command in another terminal to debug your workflow:
+
+```bash
+trinity debug --config <config_file_path> --module workflow --output-file <output_file_path> --plugin-dir <plugin_dir>
+```
+
+- `<config_file_path>`: Path to the YAML configuration file, usually the same as the one used to start the inference model.
+- `<output_file_path>`: Path to save the performance profiling results. Debug Mode uses [viztracer](https://github.com/gaogaotiantian/viztracer) to profile the workflow execution and saves the results as an HTML file for easy viewing in a browser.
+- `<plugin_dir>` (optional): Path to the plugin directory. If your workflow or reward function modules are not built into Trinity-RFT, you can specify this parameter to load custom modules.
+
+During debugging, the `buffer.explorer_input.taskset` field in the config will be loaded to initialize the task dataset and the workflow instance under test. Note that Debug Mode only reads the first sample in the dataset. After running the above command, the workflow's return value will be automatically formatted and printed in the terminal for easy inspection.
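+
+For reference, a complete debug session typically uses two terminals, as sketched below; the config path and plugin directory are placeholders for your own files:
+
+```bash
+# Terminal 1: start the inference model(s) defined in the config and keep them running
+trinity debug --config examples/my_exp/debug.yaml --module inference_model
+
+# Terminal 2: run the workflow once against the model started above
+trinity debug --config examples/my_exp/debug.yaml --module workflow \
+    --output-file debug_workflow_runner.html --plugin-dir ./my_plugins
+
+# When the run finishes, open debug_workflow_runner.html in a browser to inspect the viztracer profile
+```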
+
+When debugging is complete, you can terminate the inference model by pressing `Ctrl+C` in its terminal.
diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
index d3bf275e90..3cb85787e4 100644
--- a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
+++ b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
@@ -428,3 +428,107 @@ class ExampleWorkflow(Workflow):
 2. 调用 `chat.completions.create` 时,其中的 `model` 字段可通过 `openai_client.models.list().data[0].id` 或 `openai_client.model_path` 获取。
 3. 更复杂的使用 OpenAI API 的工作流实例可参考 [ReAct Agent 训练](./example_react.md)。
 ```
+
+#### LLM-as-a-judge 支持
+
+LLM-as-a-judge 是一种常见的奖励计算方法,尤其适用于开放式任务(如编程、写作等)。在这类场景下,Workflow 需要借助额外的 LLM 来评估答案质量并计算奖励信号(reward)。
+
+为此,Trinity-RFT 提供了 Auxiliary Models(辅助模型)机制。辅助模型是一组未参与训练的模型,Workflow 可利用这些模型辅助完成任务,例如作为评判者(judge)计算奖励。
+
+你可以在配置文件中通过 `explorer.auxiliary_models` 字段指定一个或多个辅助模型。例如:
+
+```yaml
+explorer:
+  auxiliary_models:
+    - model_path: Qwen/Qwen2.5-32B-Instruct
+      engine_num: 1
+      tensor_parallel_size: 2
+      enable_thinking: false
+      max_prompt_tokens: 12288
+      max_response_tokens: 12288
+      max_model_len: 16384
+    - model_path: Qwen/Qwen3-8B
+      engine_num: 1
+      tensor_parallel_size: 2
+      enable_thinking: false
+      max_prompt_tokens: 12288
+      max_response_tokens: 12288
+      max_model_len: 16384
+```
+
+请注意,每个辅助模型会独立占用 `tensor_parallel_size * engine_num` 个 GPU,请根据硬件资源合理配置。在启用辅助模型后,Trainer 可用的 GPU 数量为总 GPU 数量减去所有辅助模型及被训练的推理模型(`rollout_model`)所占用的 GPU 数量。
+
+配置文件中指定的辅助模型会自动激活 OpenAI API,并将对应的 `openai.OpenAI` 实例传递给 `Workflow` 初始化方法的 `auxiliary_models` 参数。例如:
+
+```python
+class MyWorkflow(Workflow):
+    def __init__(
+        self,
+        *,
+        task: Task,
+        model: ModelWrapper,
+        auxiliary_models: Optional[List[openai.OpenAI]] = None,
+    ):
+        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
+        self.judge_model = self.auxiliary_models[0]  # 使用第一个辅助模型作为评判者
+
+    def run(self) -> List[Experience]:
+        response = self.do_something()
+        reward_response = self.judge_model.chat.completions.create(
+            model=self.judge_model.model_path,
+            messages=[
+                {
+                    "role": "system",
+                    "content": "You are a judge. You need to give a score from 0 to 1 based on the quality of the answer.",
+                },
+                {
+                    "role": "user",
+                    "content": f"Question:\n{self.task.raw_task['question']}\nAnswer:\n{response.response_text}\nPlease give a score from 0 to 1.",
+                },
+            ],
+            temperature=0.0,
+            max_tokens=10,
+        )
+        # 解析奖励分数
+        reward = float(reward_response.choices[0].message.content.strip())
+        return [
+            Experience(
+                tokens=response.tokens,
+                prompt_length=response.prompt_length,
+                reward=reward,
+                logprobs=response.logprobs,
+            )
+        ]
+```
+
+#### 调试模式(Debug Mode)
+
+在 Workflow 开发过程中,频繁启动完整训练流程进行测试既耗时又低效。为此,Trinity-RFT 为开发者提供了调试模式。该模式通过预先启动推理模型,能够快速运行指定的工作流并获取结果,避免因模型加载和初始化带来的重复等待,大幅提升开发效率。流程如下:
+
+```{mermaid}
+flowchart LR
+    A[启动推理模型] --> B[调试 Workflow]
+    B --> B
+```
+
+启动推理模型的命令如下:
+
+```bash
+trinity debug --config <config_file_path> --module inference_model
+```
+
+其中,`config_file_path` 为 YAML 格式的配置文件路径,格式与 `trinity run` 命令所用配置文件一致。配置文件中的 `explorer.rollout_model` 和 `explorer.auxiliary_models` 字段会被加载,用于初始化推理模型。
+
+模型启动后会持续运行并等待调试指令,不会自动退出。此时,你可在另一个终端执行如下命令进行 Workflow 调试:
+
+```bash
+trinity debug --config <config_file_path> --module workflow --output-file <output_file_path> --plugin-dir <plugin_dir>
+```
+
+- `config_file_path`:YAML 配置文件路径,通常与启动推理模型时使用的配置文件相同。
+- `output_file_path`:性能分析结果输出路径。调试模式会使用 [viztracer](https://github.com/gaogaotiantian/viztracer) 对 Workflow 运行过程进行性能分析,并将结果保存为 HTML 文件,便于在浏览器中查看。
+- `plugin_dir`(可选):插件目录路径。如果你的 Workflow 或奖励函数等模块未内置于 Trinity-RFT,可通过该参数加载自定义模块。
+
+调试过程中,配置文件中的 `buffer.explorer_input.taskset` 字段会被加载,用于初始化 Workflow 所需的任务数据集和实例。需注意,调试模式仅会读取数据集中的第一条数据进行测试。运行上述命令后,Workflow 的返回值会自动格式化并打印在终端,方便查看运行结果。
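+
+作为参考,一次完整的调试会话通常需要两个终端,示意如下;命令中的配置文件路径和插件目录仅为示例占位,请替换为你自己的文件:
+
+```bash
+# 终端 1:启动配置文件中定义的推理模型并保持运行
+trinity debug --config examples/my_exp/debug.yaml --module inference_model
+
+# 终端 2:基于上面启动的模型运行一次 Workflow
+trinity debug --config examples/my_exp/debug.yaml --module workflow \
+    --output-file debug_workflow_runner.html --plugin-dir ./my_plugins
+
+# 运行结束后在浏览器中打开 debug_workflow_runner.html 查看 viztracer 性能分析结果
+```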
+
+调试完成后,可在推理模型终端输入 `Ctrl+C` 以终止模型运行。
diff --git a/pyproject.toml b/pyproject.toml
index 79d3fa8e80..3723fc73f2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,6 +69,7 @@ dev = [
     "pytest-json-ctrf",
     "parameterized",
     "matplotlib",
+    "viztracer",
 ]
 megatron = [
     "megatron-core[mlm]==0.13.1",
diff --git a/tests/cli/launcher_test.py b/tests/cli/launcher_test.py
index dca5da8759..067efb4a6f 100644
--- a/tests/cli/launcher_test.py
+++ b/tests/cli/launcher_test.py
@@ -1,6 +1,8 @@
+import multiprocessing
 import os
 import shutil
 import sys
+import time
 import unittest
 from unittest import mock
 from unittest.mock import MagicMock
@@ -18,6 +20,12 @@
     StageConfig,
     TrainerInput,
 )
+from trinity.common.constants import (
+    LOG_DIR_ENV_VAR,
+    LOG_LEVEL_ENV_VAR,
+    LOG_NODE_IP_ENV_VAR,
+)
+from trinity.common.models import get_debug_inference_model


 class TestLauncherMain(unittest.TestCase):
@@ -108,9 +116,9 @@ def test_main_run_in_dlc(self, mock_init, mock_load, mock_both, mock_setup, mock
             runtime_env={
                 "env_vars": {
                     launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
-                    launcher.LOG_DIR_ENV_VAR: config.log.save_dir,
-                    launcher.LOG_LEVEL_ENV_VAR: config.log.level,
-                    launcher.LOG_NODE_IP_ENV_VAR: "1",
+                    LOG_DIR_ENV_VAR: config.log.save_dir,
+                    LOG_LEVEL_ENV_VAR: config.log.level,
+                    LOG_NODE_IP_ENV_VAR: "1",
                 }
             },
         )
@@ -202,14 +210,14 @@ def test_multi_stage_run(
                 runtime_env={
                     "env_vars": {
                         launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
-                        launcher.LOG_DIR_ENV_VAR: os.path.join(
+                        LOG_DIR_ENV_VAR: os.path.join(
                             config.checkpoint_root_dir,
                             config.project,
                             f"{config.name}/sft_warmup",
                             "log",
                         ),
-                        launcher.LOG_LEVEL_ENV_VAR: config.log.level,
-                        launcher.LOG_NODE_IP_ENV_VAR: "0",
+                        LOG_LEVEL_ENV_VAR: config.log.level,
+                        LOG_NODE_IP_ENV_VAR: "0",
                     }
                 },
             ),
@@ -220,14 +228,14 @@
                 runtime_env={
                     "env_vars": {
                         launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
-                        launcher.LOG_DIR_ENV_VAR: os.path.join(
+                        LOG_DIR_ENV_VAR: os.path.join(
                             config.checkpoint_root_dir,
                             config.project,
                             f"{config.name}/grpo",
                             "log",
                         ),
-                        launcher.LOG_LEVEL_ENV_VAR: config.log.level,
-                        launcher.LOG_NODE_IP_ENV_VAR: "0",
+                        LOG_LEVEL_ENV_VAR: config.log.level,
+                        LOG_NODE_IP_ENV_VAR: "0",
                     }
                 },
             ),
@@ -241,6 +249,50 @@
             "/path/to/hf/checkpoint",
         )

+    @mock.patch("trinity.cli.launcher.load_config")
+    def test_debug_mode(self, mock_load):
+        process = multiprocessing.Process(target=debug_inference_model_process)
+        process.start()
+        time.sleep(15)  # wait for the model to be created
+        for _ in range(10):
+            try:
+                get_debug_inference_model(self.config)
+                break
+            except Exception:
+                time.sleep(3)
+        output_file = os.path.join(self.config.checkpoint_job_dir, "debug.html")
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+        mock_load.return_value = self.config
+        with mock.patch(
+            "argparse.ArgumentParser.parse_args",
+            return_value=mock.Mock(
+                command="debug",
+                config="dummy.yaml",
+                module="workflow",
+                output_file=output_file,
+                plugin_dir="",
+            ),
+        ):
+            launcher.main()
+        process.join(timeout=10)
+        process.terminate()
+        self.assertTrue(os.path.exists(output_file))

-if __name__ == "__main__":
-    unittest.main()
+
+def debug_inference_model_process():
+    config = get_template_config()
+    config.checkpoint_root_dir = get_checkpoint_path()
+    config.model.model_path = get_model_path()
+    config.check_and_update()
+    with mock.patch("trinity.cli.launcher.load_config", return_value=config):
+        with mock.patch(
+            "argparse.ArgumentParser.parse_args",
+            return_value=mock.Mock(
+                command="debug",
+                config="dummy.yaml",
+                module="inference_model",
+                plugin_dir=None,
+                output_file=None,
+            ),
+        ):
+            launcher.main()
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 3b88e85cf7..649bb348a0 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -1,5 +1,6 @@
 """Launch the trainer"""
 import argparse
+import asyncio
 import os
 import sys
 import traceback
@@ -10,12 +11,7 @@
 from trinity.buffer.pipelines.task_pipeline import check_and_run_task_pipeline
 from trinity.common.config import Config, load_config
-from trinity.common.constants import (
-    LOG_DIR_ENV_VAR,
-    LOG_LEVEL_ENV_VAR,
-    LOG_NODE_IP_ENV_VAR,
-    PLUGIN_DIRS_ENV_VAR,
-)
+from trinity.common.constants import DEBUG_NAMESPACE, PLUGIN_DIRS_ENV_VAR
 from trinity.explorer.explorer import Explorer
 from trinity.manager.state_manager import StateManager
 from trinity.trainer.trainer import Trainer

@@ -147,17 +143,11 @@ def both(config: Config) -> None:


 def run_stage(config: Config, ray_address: str) -> None:
-    envs = {
-        PLUGIN_DIRS_ENV_VAR: os.environ.get(PLUGIN_DIRS_ENV_VAR, ""),
-        LOG_DIR_ENV_VAR: config.log.save_dir,
-        LOG_LEVEL_ENV_VAR: config.log.level,
-        LOG_NODE_IP_ENV_VAR: "1" if config.log.group_by_node else "0",
-    }
     ray.init(
         address=ray_address,
         ignore_reinit_error=True,
         namespace=config.ray_namespace,
-        runtime_env={"env_vars": envs},
+        runtime_env={"env_vars": config.get_envs()},
     )
     pprint(config)
     try:
@@ -247,6 +237,40 @@ def studio(port: int = 8501):
     sys.exit(stcli.main())


+def debug(
+    config_path: str,
+    module: str,
+    output_file: str = "debug_workflow_runner.html",
+    plugin_dir: str = None,
+):
+    """Debug a module."""
+    if plugin_dir:
+        os.environ[PLUGIN_DIRS_ENV_VAR] = plugin_dir
+    load_plugins()
+    config = load_config(config_path)
+    config.check_and_update()
+    config.ray_namespace = DEBUG_NAMESPACE
+    ray.init(
+        namespace=config.ray_namespace,
+        runtime_env={"env_vars": config.get_envs()},
+        ignore_reinit_error=True,
+    )
+    from trinity.common.models import create_debug_inference_model
+
+    if module == "inference_model":
+        create_debug_inference_model(config)
+
+    elif module == "workflow":
+        from trinity.explorer.workflow_runner import DebugWorkflowRunner
+
+        runner = DebugWorkflowRunner(config, output_file)
+        asyncio.run(runner.debug())
+    else:
+        raise ValueError(
+            f"Only support 'inference_model' and 'workflow' for debugging, got {module}"
+        )
+
+
 def main() -> None:
     """The main entrypoint."""
     parser = argparse.ArgumentParser()
@@ -271,12 +295,36 @@
         "--port", type=int, default=8501, help="The port for Trinity-Studio."
     )

+    # debug command
+    debug_parser = subparsers.add_parser("debug", help="Debug the code.")
+    debug_parser.add_argument("--config", type=str, help="Path to the config file.")
+    debug_parser.add_argument(
+        "--module",
+        type=str,
+        choices=["inference_model", "workflow"],
+        help="The module to start debugging, only support 'inference_model' and 'workflow' for now.",
+    )
+    debug_parser.add_argument(
+        "--plugin-dir",
+        type=str,
+        default=None,
+        help="Path to the directory containing plugin modules.",
+    )
+    debug_parser.add_argument(
+        "--output-file",
+        type=str,
+        default="debug_workflow_runner.html",
+        help="The output file for viztracer.",
+    )
+
     args = parser.parse_args()
     if args.command == "run":
         # TODO: support parse all args from command line
         run(args.config, args.dlc, args.plugin_dir)
     elif args.command == "studio":
         studio(args.port)
+    elif args.command == "debug":
+        debug(args.config, args.module, args.output_file, args.plugin_dir)


 if __name__ == "__main__":
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 6946416a27..7d97a7a7af 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -13,7 +13,11 @@
 from trinity.common.constants import (
     EXPLORER_NAME,
+    LOG_DIR_ENV_VAR,
+    LOG_LEVEL_ENV_VAR,
+    LOG_NODE_IP_ENV_VAR,
     MAX_MODEL_LEN,
+    PLUGIN_DIRS_ENV_VAR,
     TRAINER_NAME,
     PromptType,
     StorageType,
@@ -1090,6 +1094,15 @@ def _flatten(obj, parent_key="", sep="."):
         return _flatten(self)

+    def get_envs(self) -> Dict[str, str]:
+        """Get the environment variables from the config."""
+        return {
+            PLUGIN_DIRS_ENV_VAR: os.getenv(PLUGIN_DIRS_ENV_VAR, ""),
+            LOG_LEVEL_ENV_VAR: self.log.level,
+            LOG_DIR_ENV_VAR: self.log.save_dir,
+            LOG_NODE_IP_ENV_VAR: "1" if self.log.group_by_node else "0",
+        }
+

 def load_config(config_path: str) -> Config:
     """Load the configuration from the given path."""
diff --git a/trinity/common/constants.py b/trinity/common/constants.py
index 494d418f93..a457729862 100644
--- a/trinity/common/constants.py
+++ b/trinity/common/constants.py
@@ -8,6 +8,7 @@
 TRAINER_NAME = "trainer"
 ROLLOUT_WEIGHT_SYNC_GROUP_NAME = "rollout_weight_sync"
+DEBUG_NAMESPACE = "TRINITY_DEBUG_NAMESPACE"


 # trinity env var names
 CHECKPOINT_ROOT_DIR_ENV_VAR = "TRINITY_CHECKPOINT_ROOT_DIR"
diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py
index 964e90684c..9b12b48340 100644
--- a/trinity/common/models/__init__.py
+++ b/trinity/common/models/__init__.py
@@ -1,7 +1,9 @@
+import time
 from collections import defaultdict
 from typing import List, Tuple

 from trinity.common.config import Config
+from trinity.common.constants import DEBUG_NAMESPACE
 from trinity.common.models.model import InferenceModel
 from trinity.utils.log import get_logger

@@ -82,7 +84,7 @@
     allocator = _BundleAllocator(node_bundle_map)
     namespace = ray.get_runtime_context().namespace
     # create rollout models
-    for _ in range(config.explorer.rollout_model.engine_num):
+    for i in range(config.explorer.rollout_model.engine_num):
         bundles_for_engine = allocator.allocate(config.explorer.rollout_model.tensor_parallel_size)
         config.explorer.rollout_model.bundle_indices = ",".join(
             [str(bid) for bid in bundles_for_engine]
@@ -90,6 +92,7 @@
         rollout_engines.append(
             ray.remote(engine_cls)
             .options(
+                name=f"{config.explorer.name}_rollout_model_{i}",
                 num_cpus=0,
                 num_gpus=0 if config.explorer.rollout_model.tensor_parallel_size > 1 else 1,
                 namespace=namespace,
@@ -112,9 +115,9 @@
             "history via `extract_experience_from_history` to avoid out-of-memory issues."
         )
     # create auxiliary models
-    for model_config in config.explorer.auxiliary_models:
+    for i, model_config in enumerate(config.explorer.auxiliary_models):
         engines = []
-        for _ in range(model_config.engine_num):
+        for j in range(model_config.engine_num):
             bundles_for_engine = allocator.allocate(model_config.tensor_parallel_size)
             model_config.enable_openai_api = True
             model_config.engine_type = "vllm"
@@ -122,6 +125,7 @@
             engines.append(
                 ray.remote(vLLMRolloutModel)
                 .options(
+                    name=f"{config.explorer.name}_auxiliary_model_{i}_{j}",
                     num_cpus=0,
                     num_gpus=0 if model_config.tensor_parallel_size > 1 else 1,
                     namespace=namespace,
@@ -140,3 +144,52 @@
             engine.run_api_server.remote()

     return rollout_engines, auxiliary_engines
+
+
+def create_debug_inference_model(config: Config) -> None:
+    """Create inference models for debugging."""
+    import ray
+
+    logger = get_logger(__name__)
+    logger.info("Creating inference models for debugging...")
+    # only create one engine for each model
+    config.explorer.rollout_model.engine_num = 1
+    for model in config.explorer.auxiliary_models:
+        model.engine_num = 1
+    rollout_models, auxiliary_models = create_inference_models(config)
+    # make sure models are started
+    for m in rollout_models:
+        ray.get(m.get_model_path.remote())
+    for models in auxiliary_models:
+        for m in models:
+            ray.get(m.get_model_path.remote())
+    logger.info(
+        "----------------------------------------------------\n"
+        "Inference models started successfully for debugging.\n"
+        "Press Ctrl+C to exit.\n"
+        "----------------------------------------------------"
+    )
+
+    try:
+        while True:
+            time.sleep(1)
+    except KeyboardInterrupt:
+        logger.info("Exiting...")
+
+
+def get_debug_inference_model(config: Config) -> Tuple[InferenceModel, List[InferenceModel]]:
+    """Get the inference models for debugging.
+    The models must be created by `create_debug_inference_model` in another process first.
+    """
+    import ray
+
+    rollout_model = ray.get_actor(
+        f"{config.explorer.name}_rollout_model_0", namespace=DEBUG_NAMESPACE
+    )
+    auxiliary_models = []
+    for i in range(len(config.explorer.auxiliary_models)):
+        model = ray.get_actor(
+            f"{config.explorer.name}_auxiliary_model_{i}_0", namespace=DEBUG_NAMESPACE
+        )
+        auxiliary_models.append(model)
+    return rollout_model, auxiliary_models
diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py
index d841e0e625..7a473137f0 100644
--- a/trinity/explorer/workflow_runner.py
+++ b/trinity/explorer/workflow_runner.py
@@ -7,8 +7,10 @@
 from dataclasses import dataclass
 from typing import List, Optional, Tuple

+from trinity.buffer import get_buffer_reader
 from trinity.common.config import Config
 from trinity.common.experience import Experience
+from trinity.common.models import get_debug_inference_model
 from trinity.common.models.model import InferenceModel, ModelWrapper
 from trinity.common.workflows import Task, Workflow
 from trinity.utils.log import get_logger
@@ -149,3 +151,35 @@
             error_trace_back = traceback.format_exc()
             self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}")
             return Status(False, metric={"time_per_task": time.time() - st}, message=str(e)), []
+
+
+class DebugWorkflowRunner(WorkflowRunner):
+    """A WorkflowRunner for debugging."""
+
+    def __init__(
+        self,
+        config: Config,
+        output_file: str,
+    ) -> None:
+        model, auxiliary_models = get_debug_inference_model(config)
+        super().__init__(config, model, auxiliary_models, 0)
+        self.taskset = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer)
+        self.output_file = output_file
+
+    async def debug(self) -> None:
+        """Run the debug workflow."""
+        from viztracer import VizTracer
+
+        await self.prepare()
+        tasks = await self.taskset.read_async(batch_size=1)
+        task = tasks[0]
+        self.logger.info(f"Read task: {task.task_id}, repeat_times: {task.repeat_times}")
+        with VizTracer(output_file=self.output_file):
+            status, exps = await self.run_task(task, task.repeat_times, 0)
+        if status.ok:
+            print(f"Task {task.task_id} completed successfully with metrics:\n{status.metric}")
+            for exp in exps:
+                print(f"Generated experience:\n{exp}")
+        else:
+            self.logger.error(f"Task {task.task_id} failed with message: {status.message}")
+        self.logger.info("Debugging completed.")