diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml
index 6a3f5211fd..80d3bb6029 100644
--- a/.github/workflows/unittest.yaml
+++ b/.github/workflows/unittest.yaml
@@ -119,7 +119,7 @@ jobs:
         run: |
           REPORT=report.json
           if [ -f "$REPORT" ]; then
-            jq '(.results.tests[] | .duration, .start, .stop) |= (. * 1000) | (.results.summary.start, .results.summary.stop) |= (. * 1000)' "$REPORT" > "$REPORT.tmp" && mv "$REPORT.tmp" "$REPORT"
+            jq '(.results.tests[] | .start, .stop) |= (. * 1000) | (.results.summary.start, .results.summary.stop) |= (. * 1000)' "$REPORT" > "$REPORT.tmp" && mv "$REPORT.tmp" "$REPORT"
           fi

       - name: Clean checkpoint dir
diff --git a/pyproject.toml b/pyproject.toml
index 0527fc5da4..efd60b56cf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,7 @@ dependencies = [
     "matplotlib",
     "transformers>=4.51.0",
     "datasets>=4.0.0",
+    "typer>=0.20.1",
 ]

 [project.scripts]
diff --git a/tests/cli/launcher_test.py b/tests/cli/launcher_test.py
index 412c81fda2..68ab688183 100644
--- a/tests/cli/launcher_test.py
+++ b/tests/cli/launcher_test.py
@@ -7,6 +7,8 @@
 from unittest import mock
 from unittest.mock import MagicMock

+from typer.testing import CliRunner as TyperCliRunner
+
 from tests.tools import (
     get_checkpoint_path,
     get_model_path,
@@ -27,6 +29,8 @@
 )
 from trinity.common.models import get_debug_explorer_model

+runner = TyperCliRunner()
+

 class TestLauncherMain(unittest.TestCase):
     def setUp(self):
@@ -73,13 +77,11 @@ def test_main_run_command(
         for mode in ["explore", "train", "both", "bench", "serve"]:
             config.mode = mode
             mock_load.return_value = config
-            with mock.patch(
-                "argparse.ArgumentParser.parse_args",
-                return_value=mock.Mock(
-                    command="run", config="dummy.yaml", dlc=False, plugin_dir=None
-                ),
-            ):
-                launcher.main()
+            result = runner.invoke(
+                launcher.app,
+                ["run", "--config", "dummy.yaml"],
+            )
+            self.assertEqual(result.exit_code, 0, msg=result.output)
             mock_load.assert_called_once_with("dummy.yaml")
             mapping[mode].assert_called_once_with(config)
             mock_load.reset_mock()
@@ -104,13 +106,11 @@ def test_main_run_in_dlc(self, mock_init, mock_load, mock_both, mock_setup, mock
                 "both": mock_both,
             },
         ):
-            with mock.patch(
-                "argparse.ArgumentParser.parse_args",
-                return_value=mock.Mock(
-                    command="run", config="dummy.yaml", dlc=True, plugin_dir="/path/to/plugins"
-                ),
-            ):
-                launcher.main()
+            result = runner.invoke(
+                launcher.app,
+                ["run", "--config", "dummy.yaml", "--dlc", "--plugin-dir", "/path/to/plugins"],
+            )
+            self.assertEqual(result.exit_code, 0, msg=result.output)
             mock_init.assert_called_once()
             mock_init.assert_called_once_with(
                 address="auto",
@@ -134,14 +134,20 @@
             namespace=namespace,
         )

-    @mock.patch("trinity.cli.launcher.studio")
-    def test_main_studio_command(self, mock_studio):
-        with mock.patch(
-            "argparse.ArgumentParser.parse_args",
-            return_value=mock.Mock(command="studio", port=9999),
-        ):
-            launcher.main()
-        mock_studio.assert_called_once_with(9999)
+    @mock.patch("trinity.manager.config_manager.ConfigManager.run")
+    def test_main_studio_command(self, mock_studio_fn):
+        result = runner.invoke(
+            launcher.app,
+            ["studio", "--port", "9999"],
+        )
+        self.assertEqual(result.exit_code, 0, msg=result.output)
+        mock_studio_fn.assert_called_once()
+        # Typer calls the function with keyword args; verify port was passed
+        call_kwargs = mock_studio_fn.call_args
+        # The typer-decorated function receives port=9999
+        self.assertEqual(
+            call_kwargs.kwargs.get("port", call_kwargs.args[0] if call_kwargs.args else None), 9999
+        )

     @mock.patch("trinity.trainer.verl.utils.get_latest_hf_checkpoint_path")
     @mock.patch("trinity.cli.launcher.both")
@@ -194,13 +200,17 @@ def test_multi_stage_run(
                 "both": mock_both,
             },
         ):
-            with mock.patch(
-                "argparse.ArgumentParser.parse_args",
-                return_value=mock.Mock(
-                    command="run", config="dummy.yaml", dlc=False, plugin_dir="/path/to/plugins"
-                ),
-            ):
-                launcher.main()
+            result = runner.invoke(
+                launcher.app,
+                [
+                    "run",
+                    "--config",
+                    "dummy.yaml",
+                    "--plugin-dir",
+                    "/path/to/plugins",
+                ],
+            )
+            self.assertEqual(result.exit_code, 0, msg=result.output)
             self.assertEqual(mock_init.call_count, 2)
             self.assertEqual(mock_shutdown.call_count, 2)
             mock_train.assert_called_once()
@@ -268,62 +278,66 @@ def test_debug_mode(self, mock_load):
         output_dir = os.path.join(self.config.checkpoint_job_dir, "debug_output")
         self.config.buffer.explorer_input.tasksets = [get_unittest_dataset_config("gsm8k")]
         mock_load.return_value = self.config
-        with mock.patch(
-            "argparse.ArgumentParser.parse_args",
-            return_value=mock.Mock(
-                command="debug",
-                config="dummy.yaml",
-                module="workflow",
-                enable_profiling=True,
-                disable_overwrite=False,
-                output_dir=output_dir,
-                output_file=output_file,
-                plugin_dir="",
-            ),
-        ):
-            launcher.main()
+
+        # First run: workflow with profiling enabled
+        result = runner.invoke(
+            launcher.app,
+            [
+                "debug",
+                "--config",
+                "dummy.yaml",
+                "--module",
+                "workflow",
+                "--enable-profiling",
+                "--output-dir",
+                output_dir,
+            ],
+        )
+        self.assertEqual(result.exit_code, 0, msg=result.output)
         self.assertFalse(os.path.exists(output_file))
         self.assertTrue(os.path.exists(output_dir))
         self.assertTrue(os.path.exists(os.path.join(output_dir, "profiling.html")))
         self.assertTrue(os.path.exists(os.path.join(output_dir, "experiences.db")))

+        # add a dummy file to test overwrite behavior
         with open(os.path.join(output_dir, "dummy.txt"), "w") as f:
             f.write("not empty")

-        with mock.patch(
-            "argparse.ArgumentParser.parse_args",
-            return_value=mock.Mock(
-                command="debug",
-                config="dummy.yaml",
-                module="workflow",
-                enable_profiling=False,
-                disable_overwrite=False,
-                output_dir=output_dir,
-                output_file=output_file,
-                plugin_dir="",
-            ),
-        ):
-            launcher.main()
+        # Second run: workflow without profiling, overwrite allowed (default)
+        result = runner.invoke(
+            launcher.app,
+            [
+                "debug",
+                "--config",
+                "dummy.yaml",
+                "--module",
+                "workflow",
+                "--output-dir",
+                output_dir,
+            ],
+        )
+        self.assertEqual(result.exit_code, 0, msg=result.output)
         dirs = os.listdir(self.config.checkpoint_job_dir)
         target_output_dir = [d for d in dirs if d.startswith("debug_output_")]
         self.assertEqual(len(target_output_dir), 0)

-        with mock.patch(
-            "argparse.ArgumentParser.parse_args",
-            return_value=mock.Mock(
-                command="debug",
-                config="dummy.yaml",
-                module="workflow",
-                enable_profiling=False,
-                disable_overwrite=True,
-                output_dir=output_dir,
-                output_file=output_file,
-                plugin_dir="",
-            ),
-        ):
-            launcher.main()
+        # Third run: workflow without profiling, overwrite disabled
+        result = runner.invoke(
+            launcher.app,
+            [
+                "debug",
+                "--config",
+                "dummy.yaml",
+                "--module",
+                "workflow",
+                "--disable-overwrite",
+                "--output-dir",
+                output_dir,
+            ],
+        )
+        self.assertEqual(result.exit_code, 0, msg=result.output)

         self.assertFalse(os.path.exists(output_file))
         # test the original files are not overwritten
@@ -357,14 +371,13 @@ def debug_inference_model_process():
     config.model.model_path = get_model_path()
     config.check_and_update()
mock.patch("trinity.cli.launcher.load_config", return_value=config): - with mock.patch( - "argparse.ArgumentParser.parse_args", - return_value=mock.Mock( - command="debug", - config="dummy.yaml", - module="inference_model", - plugin_dir=None, - output_file=None, - ), - ): - launcher.main() + runner.invoke( + launcher.app, + [ + "debug", + "--config", + "dummy.yaml", + "--module", + "inference_model", + ], + ) diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 21a26bdd9b..aa4ad21155 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -680,7 +680,9 @@ async def test_workflow_runner(self): workflow_args={"output_format": "json"}, ) - status, exps = await runner.run_task(task, repeat_times=3, run_id_base=0) + status, exps = await runner.run_task( + task, batch_id="test", repeat_times=3, run_id_base=0 + ) self.assertTrue(status.ok) self.assertIsInstance(exps, list) @@ -693,7 +695,9 @@ async def test_workflow_runner(self): workflow_args={"output_format": "yaml"}, ) - status, exps = await runner.run_task(task, repeat_times=2, run_id_base=0) + status, exps = await runner.run_task( + task, batch_id="test", repeat_times=2, run_id_base=0 + ) self.assertTrue(status.ok) self.assertIsInstance(exps, list) self.assertEqual(len(exps), 2) @@ -750,7 +754,10 @@ async def monitor_routine(): return count await asyncio.gather( - *[monitor_routine(), runner.run_task(task, repeat_times=3, run_id_base=0)] + *[ + monitor_routine(), + runner.run_task(task, batch_id="test", repeat_times=3, run_id_base=0), + ] ) async def test_workflow_with_openai(self): @@ -784,13 +791,15 @@ async def test_workflow_with_openai(self): ] status, exps = await runner.run_task( - tasks[0], repeat_times=2, run_id_base=0 + tasks[0], batch_id="test", repeat_times=2, run_id_base=0 ) # test exception handling self.assertEqual(status.ok, False) self.assertEqual(len(exps), 0) exps = runner.model_wrapper.extract_experience_from_history(clear_history=False) self.assertEqual(len(exps), 1) - status, exps = await runner.run_task(tasks[1], repeat_times=2, run_id_base=0) # normal run + status, exps = await runner.run_task( + tasks[1], batch_id="test", repeat_times=2, run_id_base=0 + ) # normal run self.assertEqual(status.ok, True) self.assertEqual(len(exps), 2) exps = runner.model_wrapper.extract_experience_from_history(clear_history=False) @@ -870,19 +879,23 @@ async def test_concurrent_workflow_runner(self): raw_task={"text": "Hello, world!"}, ) # warmup - async_status, async_exps = await async_runner.run_task(task, repeat_times=2, run_id_base=0) + async_status, async_exps = await async_runner.run_task( + task, batch_id="test", repeat_times=2, run_id_base=0 + ) st = time.time() - async_status, async_exps = await async_runner.run_task(task, repeat_times=4, run_id_base=0) + async_status, async_exps = await async_runner.run_task( + task, batch_id="test", repeat_times=4, run_id_base=0 + ) async_runtime = time.time() - st st = time.time() thread_status, thread_exps = await thread_runner.run_task( - task, repeat_times=4, run_id_base=0 + task, batch_id="test", repeat_times=4, run_id_base=0 ) thread_runtime = time.time() - st st = time.time() sequential_status, sequential_exps = await sequential_runner.run_task( - task, repeat_times=4, run_id_base=0 + task, batch_id="test", repeat_times=4, run_id_base=0 ) sequential_runtime = time.time() - st diff --git a/trinity/buffer/viewer.py b/trinity/buffer/viewer.py index 86d75240e7..8a6af17364 100644 --- a/trinity/buffer/viewer.py +++ 
+++ b/trinity/buffer/viewer.py
@@ -1,4 +1,6 @@
 import argparse
+import sys
+from pathlib import Path
 from typing import List

 import streamlit as st
@@ -37,6 +39,38 @@ def total_experiences(self) -> int:
             count = session.query(self.table_model_cls).count()
         return count

+    @staticmethod
+    def run_viewer(model_path: str, db_url: str, table_name: str, port: int):
+        """Start the Streamlit viewer.
+
+        Args:
+            model_path (str): Path to the tokenizer/model directory.
+            db_url (str): Database URL for the experience database.
+            table_name (str): Name of the experience table in the database.
+            port (int): Port number to run the Streamlit app on.
+        """
+
+        from streamlit.web import cli
+
+        viewer_path = Path(__file__)
+        sys.argv = [
+            "streamlit",
+            "run",
+            str(viewer_path.resolve()),
+            "--server.port",
+            str(port),
+            "--server.fileWatcherType",
+            "none",
+            "--",
+            "--db-url",
+            db_url,
+            "--table",
+            table_name,
+            "--tokenizer",
+            model_path,
+        ]
+        sys.exit(cli.main())
+

 st.set_page_config(page_title="Trinity-RFT Experience Visualizer", layout="wide")
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 3f82520988..438eb205b2 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -1,31 +1,32 @@
 """Launch the trainer"""
-import argparse
 import asyncio
 import os
 import sys
 import traceback
-from pathlib import Path
 from pprint import pprint
 from typing import Optional

 import ray
+import typer
+from typing_extensions import Annotated

-from trinity.buffer.pipelines.task_pipeline import check_and_run_task_pipeline
 from trinity.common.config import Config, load_config
 from trinity.common.constants import DEBUG_NAMESPACE, PLUGIN_DIRS_ENV_VAR
-from trinity.explorer.explorer import Explorer
 from trinity.manager.checkpoint_converter import Converter
 from trinity.manager.state_manager import StateManager
-from trinity.trainer.trainer import Trainer
 from trinity.utils.dlc_utils import is_running, setup_ray_cluster, stop_ray_cluster
 from trinity.utils.log import get_logger
 from trinity.utils.plugin_loader import load_plugins

 logger = get_logger(__name__)

+app = typer.Typer(help="Trinity CLI - Launch and manage Trinity-RFT processes.")
+

 def bench(config: Config) -> None:
     """Evaluate model."""
+    from trinity.explorer.explorer import Explorer
+
     config.explorer.name = "benchmark"
     try:
         explorer = Explorer.get_actor(config)
@@ -39,6 +40,8 @@

 def explore(config: Config) -> None:
     """Run explorer."""
+    from trinity.explorer.explorer import Explorer
+
     try:
         explorer = Explorer.get_actor(config)
         ray.get(explorer.prepare.remote())
@@ -51,6 +54,8 @@

 def train(config: Config) -> None:
     """Run trainer."""
+    from trinity.trainer.trainer import Trainer
+
     try:
         trainer = Trainer.get_actor(config)
         ray.get(trainer.prepare.remote())
@@ -63,6 +68,8 @@

 def serve(config: Config) -> None:
     """Run explorer in server mode."""
+    from trinity.explorer.explorer import Explorer
+
     try:
         explorer = Explorer.get_actor(config)
         ray.get(explorer.prepare.remote())
@@ -83,6 +90,9 @@ def both(config: Config) -> None:
     the latest step. The specific number of experiences may vary for different
     algorithms and tasks.
""" + from trinity.explorer.explorer import Explorer + from trinity.trainer.trainer import Trainer + try: explorer = Explorer.get_actor(config) trainer = Trainer.get_actor(config) @@ -154,6 +164,8 @@ def run_stage(config: Config) -> None: ) pprint(config) try: + from trinity.buffer.pipelines.task_pipeline import check_and_run_task_pipeline + check_and_run_task_pipeline(config) MODE_MAP[config.mode](config) finally: @@ -165,40 +177,60 @@ def run_stage(config: Config) -> None: ray.shutdown() -def run(config_path: str, dlc: bool = False, plugin_dir: str = None): +# --------------------------------------------------------------------------- +# CLI commands +# --------------------------------------------------------------------------- + + +@app.command() +def run( + config: Annotated[ + str, + typer.Option("--config", help="Path to the config file."), + ], + dlc: Annotated[ + bool, + typer.Option("--dlc", help="Specify when running in Aliyun PAI DLC."), + ] = False, + plugin_dir: Annotated[ + Optional[str], + typer.Option("--plugin-dir", help="Path to the directory containing plugin modules."), + ] = None, +) -> None: + """Run RFT process.""" if plugin_dir: os.environ[PLUGIN_DIRS_ENV_VAR] = plugin_dir load_plugins() - config = load_config(config_path) + cfg = load_config(config) if dlc: - cluster_namespace = f"{config.project}-{config.name}" - config.cluster.ray_address = setup_ray_cluster(namespace=cluster_namespace) + cluster_namespace = f"{cfg.project}-{cfg.name}" + cfg.cluster.ray_address = setup_ray_cluster(namespace=cluster_namespace) if not is_running(): raise RuntimeError("Ray is not running, please start it by `ray start --head`.") try: - if config.stages: + if cfg.stages: from trinity.trainer.verl.utils import get_latest_hf_checkpoint_path state_manager = StateManager( - path=os.path.join(config.checkpoint_root_dir, config.project, config.name) + path=os.path.join(cfg.checkpoint_root_dir, cfg.project, cfg.name) ) latest_stage = state_manager.load_stage().get("latest_stage", 0) prev_stage_checkpoint = None - for i, stage_config in enumerate(config): + for i, stage_config in enumerate(cfg): if i < latest_stage: logger.info( "===========================================================\n" - f"> Skipping completed stage {i + 1}/{len(config.stages)}...\n" + f"> Skipping completed stage {i + 1}/{len(cfg.stages)}...\n" "===========================================================" ) stage_config.check_and_update() else: logger.info( "===========================================================\n" - f"> Starting stage {i + 1}/{len(config.stages)}...\n" + f"> Starting stage {i + 1}/{len(cfg.stages)}...\n" "===========================================================" ) state_manager.save_stage(i) @@ -208,212 +240,132 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): run_stage(stage_config) logger.info( "===========================================================\n" - f"> Stage {i + 1}/{len(config.stages)} finished.\n" + f"> Stage {i + 1}/{len(cfg.stages)} finished.\n" "===========================================================" ) prev_stage_checkpoint = get_latest_hf_checkpoint_path(stage_config) else: - config.check_and_update() - run_stage(config) + cfg.check_and_update() + run_stage(cfg) finally: if dlc: stop_ray_cluster(namespace=cluster_namespace) -def studio(port: int = 8501): - from streamlit.web import cli as stcli +@app.command() +def studio( + port: Annotated[ + int, + typer.Option("--port", help="The port for Trinity-Studio."), + ] = 8501, +) -> None: + """Run 
+    """Run studio to manage configurations."""
+    from trinity.manager.config_manager import ConfigManager

-    current_dir = Path(__file__).resolve().parent.parent
-    config_manager_path = os.path.join(current_dir, "manager", "config_manager.py")
-
-    sys.argv = [
-        "streamlit",
-        "run",
-        config_manager_path,
-        "--server.port",
-        str(port),
-        "--server.fileWatcherType",
-        "none",
-    ]
-    sys.exit(stcli.main())
+    ConfigManager.run(port)


+@app.command()
 def debug(
-    config_path: str,
-    module: str,
-    output_dir: str = "debug_output",
-    disable_overwrite: bool = False,
-    enable_profiling: bool = False,
-    port: int = 8502,
-    plugin_dir: str = None,
-):
-    """Debug a module."""
+    config: Annotated[
+        str,
+        typer.Option("--config", help="Path to the config file."),
+    ],
+    module: Annotated[
+        str,
+        typer.Option(
+            "--module",
+            help="The module to debug: 'inference_model', 'workflow', or 'viewer'.",
+        ),
+    ],
+    plugin_dir: Annotated[
+        Optional[str],
+        typer.Option("--plugin-dir", help="Path to the directory containing plugin modules."),
+    ] = None,
+    output_dir: Annotated[
+        str,
+        typer.Option("--output-dir", help="The output directory for debug files."),
+    ] = "debug_output",
+    disable_overwrite: Annotated[
+        bool,
+        typer.Option("--disable-overwrite", help="Disable overwriting the output directory."),
+    ] = False,
+    enable_profiling: Annotated[
+        bool,
+        typer.Option("--enable-profiling", help="Whether to use viztracer for workflow profiling."),
+    ] = False,
+    port: Annotated[
+        int,
+        typer.Option("--port", help="The port for Experience Viewer."),
+    ] = 8502,
+) -> None:
+    """Debug a module: 'inference_model', 'workflow', or 'viewer'."""
+    valid_modules = ("inference_model", "workflow", "viewer")
+    if module not in valid_modules:
+        raise typer.BadParameter(f"Only support {valid_modules} for debugging, got '{module}'")
+
     if plugin_dir:
         os.environ[PLUGIN_DIRS_ENV_VAR] = plugin_dir
     load_plugins()
-    config = load_config(config_path)
-    config.mode = "explore"
-    config.ray_namespace = DEBUG_NAMESPACE
-    config.check_and_update()
+    cfg = load_config(config)
+    cfg.mode = "explore"
+    cfg.ray_namespace = DEBUG_NAMESPACE
+    cfg.check_and_update()
     sys.path.insert(0, os.getcwd())
     ray.init(
-        namespace=config.ray_namespace,
-        runtime_env={"env_vars": config.get_envs()},
+        namespace=cfg.ray_namespace,
+        runtime_env={"env_vars": cfg.get_envs()},
         ignore_reinit_error=True,
     )
+
     from trinity.common.models import create_debug_explorer_model

     if module == "inference_model":
-        asyncio.run(create_debug_explorer_model(config))
+        asyncio.run(create_debug_explorer_model(cfg))
     elif module == "workflow":
         from trinity.explorer.workflow_runner import DebugWorkflowRunner

-        runner = DebugWorkflowRunner(config, output_dir, enable_profiling, disable_overwrite)
+        runner = DebugWorkflowRunner(cfg, output_dir, enable_profiling, disable_overwrite)
         asyncio.run(runner.debug())
+
     elif module == "viewer":
-        from streamlit.web import cli as stcli
-
-        current_dir = Path(__file__).resolve().parent.parent
-        viewer_path = os.path.join(current_dir, "buffer", "viewer.py")
-        output_dir_abs = os.path.abspath(output_dir)
-        if output_dir_abs.endswith("/"):
-            output_dir_abs = output_dir_abs[:-1]
-        print(f"sqlite:///{output_dir_abs}/experiences.db")
-        sys.argv = [
-            "streamlit",
-            "run",
-            viewer_path,
-            "--server.port",
-            str(port),
-            "--server.fileWatcherType",
-            "none",
-            "--",
-            "--db-url",
-            f"sqlite:///{output_dir_abs}/experiences.db",
-            "--table",
-            "debug_buffer",
-            "--tokenizer",
-            config.model.model_path,
-        ]
-        sys.exit(stcli.main())
-    else:
-        raise ValueError(
-            f"Only support 'inference_model' and 'workflow' for debugging, got {module}"
+        from trinity.buffer.viewer import SQLExperienceViewer
+
+        output_dir_abs = os.path.abspath(output_dir).rstrip("/")
+
+        SQLExperienceViewer.run_viewer(
+            model_path=cfg.model.model_path,
+            db_url=f"sqlite:///{os.path.join(output_dir_abs, 'debug_buffer.db')}",
+            table_name="debug_buffer",
+            port=port,
         )


-def convert(checkpoint_dir: str, base_model_dir: Optional[str] = None) -> None:
-    if "global_step_" in checkpoint_dir:
-        while not os.path.basename(checkpoint_dir).startswith("global_step_"):
-            checkpoint_dir = os.path.dirname(checkpoint_dir)
+@app.command()
+def convert(
+    checkpoint_dir: Annotated[
+        str,
+        typer.Option("--checkpoint-dir", help="The path to the checkpoint directory."),
+    ],
+    base_model_dir: Annotated[
+        Optional[str],
+        typer.Option("--base-model-dir", help="The path to the base model."),
+    ] = None,
+) -> None:
+    """Convert checkpoints to huggingface format."""
+    dir_path = checkpoint_dir
+    if "global_step_" in dir_path:
+        while not os.path.basename(dir_path).startswith("global_step_"):
+            dir_path = os.path.dirname(dir_path)
     converter = Converter(base_model_dir)
-    converter.convert(checkpoint_dir)
+    converter.convert(dir_path)


 def main() -> None:
     """The main entrypoint."""
-    parser = argparse.ArgumentParser()
-    subparsers = parser.add_subparsers(dest="command", required=True)
-
-    # run command
-    run_parser = subparsers.add_parser("run", help="Run RFT process.")
-    run_parser.add_argument("--config", type=str, required=True, help="Path to the config file.")
-    run_parser.add_argument(
-        "--plugin-dir",
-        type=str,
-        default=None,
-        help="Path to the directory containing plugin modules.",
-    )
-    run_parser.add_argument(
-        "--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC."
-    )
-
-    # studio command
-    studio_parser = subparsers.add_parser("studio", help="Run studio.")
-    studio_parser.add_argument(
-        "--port", type=int, default=8501, help="The port for Trinity-Studio."
-    )
-
-    # debug command
-    debug_parser = subparsers.add_parser("debug", help="Debug the code.")
-    debug_parser.add_argument("--config", type=str, help="Path to the config file.")
-    debug_parser.add_argument(
-        "--module",
-        type=str,
-        choices=["inference_model", "workflow", "viewer"],
-        help="The module to start debugging, only support 'inference_model', 'workflow' and 'viewer' for now.",
-    )
-    debug_parser.add_argument(
-        "--plugin-dir",
-        type=str,
-        default=None,
-        help="Path to the directory containing plugin modules.",
-    )
-    debug_parser.add_argument(
-        "--output-dir",
-        type=str,
-        default="debug_output",
-        help="The output directory for debug files.",
-    )
-    debug_parser.add_argument(
-        "--disable-overwrite", action="store_true", help="Disable overwriting the output directory."
-    )
-    debug_parser.add_argument(
-        "--enable-profiling",
-        action="store_true",
-        help="Whether to use viztracer for workflow profiling.",
-    )
-    debug_parser.add_argument(
-        "--output-file",
-        type=str,
-        default=None,
-        help="[DEPRECATED] Please use --output-dir instead.",
-    )
-    debug_parser.add_argument(
-        "--port",
-        type=int,
-        default=8502,
-        help="The port for Experience Viewer.",
-    )
-
-    convert_parser = subparsers.add_parser(
-        "convert", help="Convert checkpoint to huggingface format."
-    )
-    convert_parser.add_argument(
-        "--checkpoint-dir",
-        type=str,
-        required=True,
-        help="The path to the checkpoint directory.",
-    )
-    convert_parser.add_argument(
-        "--base-model-dir",
-        type=str,
-        default=None,
-        help="The path to the base model.",
-    )
-
-    args = parser.parse_args()
-    if args.command == "run":
-        # TODO: support parse all args from command line
-        run(args.config, args.dlc, args.plugin_dir)
-    elif args.command == "studio":
-        studio(args.port)
-    elif args.command == "debug":
-        debug(
-            args.config,
-            args.module,
-            args.output_dir,
-            args.disable_overwrite,
-            args.enable_profiling,
-            args.port,
-            args.plugin_dir,
-        )
-    elif args.command == "convert":
-        convert(args.checkpoint_dir, args.base_model_dir)
-    else:
-        raise ValueError(f"Unknown command: {args.command}")
+    app()


 if __name__ == "__main__":
diff --git a/trinity/common/experience.py b/trinity/common/experience.py
index 997cf3ab4a..d9d519252e 100644
--- a/trinity/common/experience.py
+++ b/trinity/common/experience.py
@@ -8,10 +8,10 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

 import torch
+from torch import Tensor

 if TYPE_CHECKING:
     from datasets import Dataset
-from torch import Tensor


 @dataclass
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index 676df1bf0c..b9bd37e1ea 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -4,11 +4,10 @@
 import copy
 import socket
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union

 import httpx
 import numpy as np
-import openai
 import ray
 import torch
 from PIL import Image
@@ -20,6 +19,9 @@
 from trinity.common.models.utils import get_action_mask_method
 from trinity.utils.log import get_logger

+if TYPE_CHECKING:
+    import openai
+

 class InferenceModel(ABC):
     """A model for high performance for rollout inference."""
@@ -499,12 +501,14 @@ async def get_lora_request_async(self) -> Any:
     async def get_message_token_len(self, messages: List[dict]) -> int:
         return await self.model.get_message_token_len.remote(messages)

-    def get_openai_client(self) -> openai.OpenAI:
+    def get_openai_client(self) -> "openai.OpenAI":
         """Get the openai client.

         Returns:
             openai.OpenAI: The openai client. And `model_path` is added to the client which refers to the model path.
         """
+        import openai
+
         if self.openai_client is not None:
             setattr(self.openai_client, "model_path", self.model_path)
             return self.openai_client
@@ -558,12 +562,14 @@ def record_chat_completions(*args, **kwargs):
         setattr(self.openai_client, "model_path", self.model_path)
         return self.openai_client

-    def get_openai_async_client(self) -> openai.AsyncOpenAI:
+    def get_openai_async_client(self) -> "openai.AsyncOpenAI":
         """Get the async openai client.

         Returns:
             openai.AsyncOpenAI: The async openai client. And `model_path` is added to the client which refers to the model path.
""" + import openai + if self.openai_async_client is not None: setattr(self.openai_async_client, "model_path", self.model_path) return self.openai_async_client diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index e1bb8fde5d..e2ac18e12a 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -200,6 +200,7 @@ async def run_with_retry( status, exps = await asyncio.wait_for( self.runner.run_task.remote( task=task2run, + batch_id=str(task.batch_id), repeat_times=repeat_times, run_id_base=run_id_base, ), @@ -477,8 +478,9 @@ def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None: Args: tasks (`List[Task]`): The tasks to schedule. - batch_id (`Union[int, str]`): The id of provided tasks. It should be an integer or a string - starting with an integer (e.g., 123, "123/my_task") + batch_id (`Union[int, str]`): + The id of provided tasks. In most cases, it should be current step number for + training tasks and "/" for eval tasks. """ if not tasks: return diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 70f712386f..9edd412411 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -56,7 +56,8 @@ def __init__( auxiliary_models: Optional[List[InferenceModel]] = None, runner_id: Optional[int] = None, ) -> None: - self.logger = get_logger(f"{config.explorer.name}_runner_{runner_id}", in_ray_actor=True) + self.name = f"{config.explorer.name}_{runner_id}" + self.logger = get_logger(self.name, in_ray_actor=True) self.config = config self.model = model self.model_wrapper = ModelWrapper( @@ -95,6 +96,11 @@ def __init__( f"Unknown concurrent_mode {self.concurrent_mode}, defaulting to sequential." ) self.concurrent_run_fn = self._sequential_run + self.logger.info( + f"WorkflowRunner [{self.name}]({self.concurrent_mode}) initialized:\n" + f" > rollout model: {self.config.explorer.rollout_model.model_path}" + f" > auxiliary models: {[aux_model_config.model_path for aux_model_config in self.config.explorer.auxiliary_models]}" + ) async def prepare(self) -> None: """Prepare the runner.""" @@ -102,6 +108,7 @@ async def prepare(self) -> None: self.model_wrapper.prepare(), *(aux_model.prepare() for aux_model in self.auxiliary_model_wrappers), ) + self.logger.info(f"WorkflowRunner [{self.name}] is prepared and ready to run tasks.") def is_alive(self): return True @@ -134,7 +141,6 @@ async def _run_task( self, task: Task, repeat_times: int, run_id_base: int ) -> Tuple[List[Experience], List[Dict]]: """Init workflow from the task and run it.""" - if task.workflow.can_repeat: workflow_instance = self._create_workflow_instance(task) workflow_instance.set_repeat_times(repeat_times, run_id_base) @@ -267,6 +273,7 @@ async def get_runner_state(self) -> Dict: async def run_task( self, task: Task, + batch_id: str, repeat_times: int = 1, run_id_base: int = 0, ) -> Tuple[Status, List[Experience]]: @@ -276,6 +283,9 @@ async def run_task( st = time.time() model_version = await self.model_wrapper.model_version_async self.runner_state["model_version"] = model_version + self.logger.info( + f"Starting task: step={batch_id}, model_version={model_version}, repeat_times={repeat_times}, run_id_base={run_id_base}" + ) exps, metrics = await self._run_task(task, repeat_times, run_id_base) assert exps is not None and len(exps) > 0, "An empty experience is generated" # set eid for each experience @@ -364,12 +374,16 @@ async def debug(self) -> None: task = tasks[0] self.logger.info(f"Start 
         self.logger.info(f"Start debugging task:\n{task.raw_task}")
         if not self.enable_profiling:
-            status, exps = await self.run_task(task, 1, 0)
+            status, exps = await self.run_task(
+                task=task, batch_id="debug", repeat_times=1, run_id_base=0
+            )
         else:
             from viztracer import VizTracer

             with VizTracer(output_file=self.output_profiling_file):
-                status, exps = await self.run_task(task, 1, 0)
+                status, exps = await self.run_task(
+                    task=task, batch_id="debug", repeat_times=1, run_id_base=0
+                )
         if not status.ok and len(exps) == 0:
             exps = self.model_wrapper.extract_experience_from_history()
             self.logger.info(f"Debugging failed, extracting {len(exps)} experiences from history.")
diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py
index e9e5672da9..3ba2605ed6 100644
--- a/trinity/manager/config_manager.py
+++ b/trinity/manager/config_manager.py
@@ -1,11 +1,13 @@
 import copy
 import os
 import subprocess
+import sys
 import tempfile
 from typing import List

 import streamlit as st
 import yaml
+from streamlit.web import cli

 from trinity.algorithm import ALGORITHM_TYPE
 from trinity.algorithm.advantage_fn import ADVANTAGE_FN
@@ -58,6 +60,20 @@ def __init__(self):
             st.session_state.is_running = False
         self.generate_config()

+    @staticmethod
+    def run(port: int):
+        config_manager_path = os.path.abspath(__file__)
+        sys.argv = [
+            "streamlit",
+            "run",
+            config_manager_path,
+            "--server.port",
+            str(port),
+            "--server.fileWatcherType",
+            "none",
+        ]
+        sys.exit(cli.main())
+
     def reset_session_state(self):
         st.session_state["_init_config_manager"] = True
         for key, value in CONFIG_GENERATORS.default_config.items():
diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py
index 71cebdf4b7..8682e74c76 100644
--- a/trinity/manager/synchronizer.py
+++ b/trinity/manager/synchronizer.py
@@ -10,10 +10,6 @@

 from trinity.common.config import Config
 from trinity.common.constants import RunningStatus, SyncMethod
-from trinity.common.models.utils import (
-    get_checkpoint_dir_with_step_num,
-    load_state_dict,
-)
 from trinity.utils.log import get_logger


@@ -94,6 +90,8 @@ async def _find_latest_state_dict(self) -> None:
         )

     async def _find_verl_latest_state_dict(self) -> None:
+        from trinity.common.models.utils import load_state_dict
+
         default_local_dir = self.config.checkpoint_job_dir
         local_latest_state_dict_iteration = os.path.join(
             default_local_dir, "latest_state_dict_iteration.txt"
@@ -219,6 +217,11 @@ async def set_model_state_dict_with_step_num(
         Returns:
             The updated model version (step number).
         """
+        from trinity.common.models.utils import (
+            get_checkpoint_dir_with_step_num,
+            load_state_dict,
+        )
+
         if world_size is not None:  # Used when trainer updates the model
             assert step_num is not None
             assert self.checkpoint_shard_counter[step_num] < world_size, "World size mismatch!"
diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py
index e225a7ea2d..ade4958d1b 100644
--- a/trinity/utils/monitor.py
+++ b/trinity/utils/monitor.py
@@ -268,22 +268,12 @@ def default_args(cls) -> Dict:


 class SwanlabMonitor(Monitor):
-    """Monitor with SwanLab.
-
-    This monitor integrates with SwanLab (https://swanlab.cn/) to track experiments.
-
-    Supported monitor_args in config.monitor.monitor_args:
-    - api_key (Optional[str]): API key for swanlab.login(). If omitted, will read from env
-      (SWANLAB_API_KEY, SWANLAB_APIKEY, SWANLAB_KEY, SWANLAB_TOKEN) or assume prior CLI login.
-    - workspace (Optional[str]): Organization/username workspace.
-    - mode (Optional[str]): "cloud" | "local" | "offline" | "disabled".
-    - logdir (Optional[str]): Local log directory when in local/offline modes.
-    - experiment_name (Optional[str]): Explicit experiment name. Defaults to "{name}_{role}".
-    - description (Optional[str]): Experiment description.
-    - tags (Optional[List[str]]): Tags to attach. Role and group are appended automatically.
-    - id (Optional[str]): Resume target run id (21 chars) when using resume modes.
-    - resume (Optional[Literal['must','allow','never']|bool]): Resume policy.
-    - reinit (Optional[bool]): Whether to re-init on repeated init() calls.
+    """Monitor with SwanLab (https://swanlab.cn/).
+
+    Set the `SWANLAB_API_KEY` environment variable to your SwanLab API key before using this monitor.
+    If you're using a local deployment of SwanLab, also set the `SWANLAB_API_HOST` environment variable.
+    Pass additional SwanLab initialization arguments via `config.monitor.monitor_args` in the Config,
+    such as `tags`, `description`, `logdir`, etc. See the SwanLab documentation for details.
     """

     def __init__(
@@ -293,11 +283,7 @@ def __init__(
         assert (
             swanlab is not None
         ), "swanlab is not installed. Please install it to use SwanlabMonitor."
-        monitor_args = (
-            (config.monitor.monitor_args or {})
-            if config and getattr(config, "monitor", None)
-            else {}
-        )
+        monitor_args = config.monitor.monitor_args or {}

         # Optional API login via code if provided; otherwise try environment, then rely on prior `swanlab login`.
         api_key = os.environ.get("SWANLAB_API_KEY")
@@ -308,7 +294,7 @@ def __init__(
                 # Best-effort login; continue to init which may still work if already logged in
                 pass
         else:
-            raise RuntimeError("Swanlab API key not found in environment variable SWANLAB_API_KEY.")
+            raise RuntimeError("SWANLAB_API_KEY environment variable not set.")

         # Compose tags (ensure list and include role/group markers)
         tags = monitor_args.get("tags") or []