From 18990d4115d2438b3fcfbba223ebbe762e616bf7 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Sat, 16 Aug 2025 13:05:37 +0800 Subject: [PATCH 1/5] bug fix in load_plugins --- trinity/__init__.py | 4 ++++ trinity/cli/launcher.py | 11 +++++++-- trinity/explorer/scheduler.py | 5 +++- trinity/utils/dlc_utils.py | 5 ++-- trinity/utils/plugin_loader.py | 44 +++++++++++++++++++++------------- 5 files changed, 48 insertions(+), 21 deletions(-) diff --git a/trinity/__init__.py b/trinity/__init__.py index 63f1db4fdc..95ddbfd44b 100644 --- a/trinity/__init__.py +++ b/trinity/__init__.py @@ -2,3 +2,7 @@ """Trinity-RFT (Reinforcement Fine-Tuning)""" __version__ = "0.2.1.dev0" + +from trinity.utils.plugin_loader import load_plugins + +load_plugins() diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 50bcf5b8c0..68fc2bf11b 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -188,16 +188,23 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): f"{data_processor_config.data_processor_url}/{DataProcessorPipelineType.EXPERIENCE.value}", config_path, ) + envs = os.environ.copy() + if "PLUGIN_DIRS" in envs: + envs["PLUGIN_DIRS"] = os.pathsep.join(plugin_dir, envs["PLUGIN_DIRS"]) + else: + envs["PLUGIN_DIRS"] = plugin_dir if dlc: from trinity.utils.dlc_utils import setup_ray_cluster - setup_ray_cluster(namespace=config.ray_namespace) + setup_ray_cluster(namespace=config.ray_namespace, envs=envs) else: from trinity.utils.dlc_utils import is_running if not is_running: raise RuntimeError("Ray is not running, please start it by `ray start --head`.") - ray.init(namespace=config.ray_namespace, ignore_reinit_error=True) + ray.init( + namespace=config.ray_namespace, ignore_reinit_error=True, runtime_env={"env_vars": envs} + ) try: if config.mode == "explore": explore(config) diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 4a4172feef..c529eb759c 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -1,6 +1,7 @@ """Scheduler for rollout tasks.""" import asyncio +import os import re import time import traceback @@ -49,13 +50,15 @@ def __init__( self.runner = self._create_runner() def _create_runner(self): + envs = os.environ.copy() + envs.update(self.config.explorer.env_vars) return ( ray.remote(WorkflowRunner) .options( namespace=self.namespace, scheduling_strategy="SPREAD", runtime_env={ - "env_vars": self.config.explorer.env_vars, + "env_vars": envs, }, ) .remote(self.config, self.rollout_model, self.auxiliary_models, self.runner_id) diff --git a/trinity/utils/dlc_utils.py b/trinity/utils/dlc_utils.py index 380e7d05b9..5bf61c5f06 100644 --- a/trinity/utils/dlc_utils.py +++ b/trinity/utils/dlc_utils.py @@ -64,14 +64,14 @@ def wait_for_ray_worker_nodes(world_size: int) -> None: time.sleep(1) -def setup_ray_cluster(namespace: str): +def setup_ray_cluster(namespace: str, envs: dict): env_vars = get_dlc_env_vars() is_master = env_vars["RANK"] == 0 if is_running(): # reuse existing ray cluster if is_master: - ray.init(namespace=namespace, ignore_reinit_error=True) + ray.init(namespace=namespace, ignore_reinit_error=True, runtime_env={"env_vars": envs}) else: if is_master: cmd = f"ray start --head --port={env_vars['MASTER_PORT']} --node-ip-address={env_vars['MASTER_ADDR']}" @@ -91,6 +91,7 @@ def setup_ray_cluster(namespace: str): address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}", namespace=namespace, ignore_reinit_error=True, + runtime_env={"env_vars": envs}, ) if is_master: # master wait for worker nodes to join diff --git a/trinity/utils/plugin_loader.py b/trinity/utils/plugin_loader.py index 3a4e935d7d..59fa580a4c 100644 --- a/trinity/utils/plugin_loader.py +++ b/trinity/utils/plugin_loader.py @@ -10,27 +10,39 @@ logger = get_logger(__name__) +loaded_dirs = [] -def load_plugins(plugin_dir: str) -> None: + +def load_plugins(plugin_dirs: str = None) -> None: """ Load plugin modules from a directory. """ - if plugin_dir is None: - plugin_dir = Path(__file__).parent.parent / "plugins" - if not os.path.exists(plugin_dir): - logger.error(f"--plugin-dir [{plugin_dir}] does not exist.") - return None - if not os.path.isdir(plugin_dir): - logger.error(f"--plugin-dir [{plugin_dir}] is not a directory.") - return None - - logger.info(f"Loading plugin modules from [{plugin_dir}]...") - for file in Path(plugin_dir).glob("*.py"): - if file.name.startswith("__"): + global loaded_dirs + if plugin_dirs is None: + plugin_dirs = Path(__file__).parent.parent / "plugins" + if os.environ.get("PLUGIN_DIRS", None): + plugin_dirs = os.pathsep.join(plugin_dirs, os.environ["PLUGIN_DIRS"]) + for plugin_dir in plugin_dirs.split(os.pathsep): + plugin_dir = plugin_dir.strip() + if not plugin_dir: + continue + if plugin_dir in loaded_dirs: + continue + if not os.path.exists(plugin_dir): + logger.error(f"--plugin-dir [{plugin_dir}] does not exist.") continue - logger.info(f"Loading plugin modules from [{file}]...") - # load modules from file - load_from_file(os.path.join(plugin_dir, file)) + if not os.path.isdir(plugin_dir): + logger.error(f"--plugin-dir [{plugin_dir}] is not a directory.") + continue + + logger.info(f"Loading plugin modules from [{plugin_dir}]...") + loaded_dirs.append(plugin_dir) + for file in Path(plugin_dir).glob("*.py"): + if file.name.startswith("__"): + continue + logger.info(f"Loading plugin modules from [{file}]...") + # load modules from file + load_from_file(os.path.join(plugin_dir, file)) def load_from_file(file_path: str): From 4b49ba243c2751d90d237ef99f846f57d4dabd8e Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Mon, 18 Aug 2025 11:15:30 +0800 Subject: [PATCH 2/5] bug fix in plugin_loader and explorer --- trinity/cli/launcher.py | 6 ++---- trinity/explorer/explorer.py | 2 +- trinity/utils/plugin_loader.py | 17 ++++++++--------- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 68fc2bf11b..b5c256699b 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -18,7 +18,6 @@ from trinity.explorer.explorer import Explorer from trinity.trainer.trainer import Trainer from trinity.utils.log import get_logger -from trinity.utils.plugin_loader import load_plugins logger = get_logger(__name__) @@ -159,7 +158,6 @@ def both(config: Config) -> None: def run(config_path: str, dlc: bool = False, plugin_dir: str = None): - load_plugins(plugin_dir) config = load_config(config_path) config.check_and_update() pprint(config) @@ -190,9 +188,9 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): ) envs = os.environ.copy() if "PLUGIN_DIRS" in envs: - envs["PLUGIN_DIRS"] = os.pathsep.join(plugin_dir, envs["PLUGIN_DIRS"]) + envs["PLUGIN_DIRS"] = os.pathsep.join([plugin_dir, envs["PLUGIN_DIRS"]]) else: - envs["PLUGIN_DIRS"] = plugin_dir + envs["PLUGIN_DIRS"] = plugin_dir or "" if dlc: from trinity.utils.dlc_utils import setup_ray_cluster diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index ba422bf5dd..0cda95c488 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -76,7 +76,7 @@ def __init__(self, config: Config): self.logger.info("Finished initializing Explorer.") self.collect_experiences = self.config.explorer.collect_experiences self.generated_experience_cnt = 0 - if self.collect_experiences: + if self.collect_experiences and self.config.mode != "bench": assert ( self.experience_buffer is not None ), "Experience buffer is required when collect_experiences is True." diff --git a/trinity/utils/plugin_loader.py b/trinity/utils/plugin_loader.py index 59fa580a4c..c11dc36a0c 100644 --- a/trinity/utils/plugin_loader.py +++ b/trinity/utils/plugin_loader.py @@ -10,7 +10,7 @@ logger = get_logger(__name__) -loaded_dirs = [] +loaded_dirs = set() def load_plugins(plugin_dirs: str = None) -> None: @@ -19,15 +19,15 @@ def load_plugins(plugin_dirs: str = None) -> None: """ global loaded_dirs if plugin_dirs is None: - plugin_dirs = Path(__file__).parent.parent / "plugins" - if os.environ.get("PLUGIN_DIRS", None): - plugin_dirs = os.pathsep.join(plugin_dirs, os.environ["PLUGIN_DIRS"]) - for plugin_dir in plugin_dirs.split(os.pathsep): - plugin_dir = plugin_dir.strip() - if not plugin_dir: - continue + plugin_dirs = [Path(__file__).parent.parent / "plugins"] + for plugin_dir in os.environ.get("PLUGIN_DIRS", "").split(os.pathsep): + plugin_dir = plugin_dir.strip() + if plugin_dir: + plugin_dirs.append(plugin_dir) + for plugin_dir in plugin_dirs: if plugin_dir in loaded_dirs: continue + loaded_dirs.add(plugin_dir) if not os.path.exists(plugin_dir): logger.error(f"--plugin-dir [{plugin_dir}] does not exist.") continue @@ -36,7 +36,6 @@ def load_plugins(plugin_dirs: str = None) -> None: continue logger.info(f"Loading plugin modules from [{plugin_dir}]...") - loaded_dirs.append(plugin_dir) for file in Path(plugin_dir).glob("*.py"): if file.name.startswith("__"): continue From 7537bf9b2384b482a95180ebf926ce8bc659bd39 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Mon, 18 Aug 2025 15:02:07 +0800 Subject: [PATCH 3/5] add custom monitor test && fix in pyproject.toml --- pyproject.toml | 1 + tests/utils/plugin_test.py | 26 ++++++++++++++++++++++++-- tests/utils/plugins/my_monitor.py | 30 ++++++++++++++++++++++++++++++ trinity/utils/plugin_loader.py | 5 ++++- 4 files changed, 59 insertions(+), 3 deletions(-) create mode 100644 tests/utils/plugins/my_monitor.py diff --git a/pyproject.toml b/pyproject.toml index 0a01b05f9d..f9c891e22a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ "jsonlines", "sortedcontainers", "word2number", + "transformers<4.54.0", ] [project.scripts] diff --git a/tests/utils/plugin_test.py b/tests/utils/plugin_test.py index e35e6eab0d..1e91a72689 100644 --- a/tests/utils/plugin_test.py +++ b/tests/utils/plugin_test.py @@ -1,22 +1,37 @@ +import os import unittest from pathlib import Path import ray +from tests.tools import TensorBoardParser, get_template_config +from trinity.common.config import Config from trinity.common.workflows import WORKFLOWS +from trinity.utils.monitor import MONITOR from trinity.utils.plugin_loader import load_plugins -@ray.remote class PluginActor: + def __init__(self, config: Config): + self.config = config + self.monitor = MONITOR.get("my_monitor")( + project=self.config.project, + group=self.config.group, + name=self.config.name, + role=self.config.explorer.name, + config=config, + ) + def run(self): my_plugin_cls = WORKFLOWS.get("my_workflow") + self.monitor.log({"rollout": 2}, step=1, commit=True) return my_plugin_cls(task=None, model=None).run() class TestPluginLoader(unittest.TestCase): def test_load_plugins(self): ray.init(ignore_reinit_error=True) + config = get_template_config() my_plugin_cls = WORKFLOWS.get("my_workflow") self.assertIsNone(my_plugin_cls) load_plugins(Path(__file__).resolve().parent / "plugins") @@ -27,8 +42,15 @@ def test_load_plugins(self): res = my_plugin.run() self.assertEqual(res[0], "Hello world") self.assertEqual(res[1], "Hi") - remote_plugin = PluginActor.remote() + + # Remote Actor test + remote_plugin = ray.remote(PluginActor).remote(config) remote_res = ray.get(remote_plugin.run.remote()) self.assertEqual(remote_res[0], "Hello world") self.assertEqual(remote_res[1], "Hi") + + # test custom monitor + parser = TensorBoardParser(os.path.join(config.monitor.cache_dir, "tensorboard")) + rollout_cnt = parser.metric_values("rollout") + self.assertEqual(rollout_cnt, [2]) ray.shutdown(_exiting_interpreter=True) diff --git a/tests/utils/plugins/my_monitor.py b/tests/utils/plugins/my_monitor.py new file mode 100644 index 0000000000..41a6015ced --- /dev/null +++ b/tests/utils/plugins/my_monitor.py @@ -0,0 +1,30 @@ +import os + +from torch.utils.tensorboard import SummaryWriter + +from trinity.common.config import Config +from trinity.utils.log import get_logger +from trinity.utils.monitor import MONITOR, Monitor + + +@MONITOR.register_module("my_monitor") +class MyMonitor(Monitor): + def __init__( + self, project: str, group: str, name: str, role: str, config: Config = None + ) -> None: + self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard", role) + os.makedirs(self.tensorboard_dir, exist_ok=True) + self.logger = SummaryWriter(self.tensorboard_dir) + self.console_logger = get_logger(__name__) + + def log_table(self, table_name: str, experiences_table, step: int): + pass + + def log(self, data: dict, step: int, commit: bool = False) -> None: + """Log metrics.""" + for key in data: + self.logger.add_scalar(key, data[key], step) + self.console_logger.info(f"Step {step}: {data}") + + def close(self) -> None: + self.logger.close() diff --git a/trinity/utils/plugin_loader.py b/trinity/utils/plugin_loader.py index c11dc36a0c..eb617e496e 100644 --- a/trinity/utils/plugin_loader.py +++ b/trinity/utils/plugin_loader.py @@ -5,6 +5,7 @@ import shutil import sys from pathlib import Path +from typing import List, Union from trinity.utils.log import get_logger @@ -13,7 +14,7 @@ loaded_dirs = set() -def load_plugins(plugin_dirs: str = None) -> None: +def load_plugins(plugin_dirs: Union[str, List[str]] = None) -> None: """ Load plugin modules from a directory. """ @@ -24,6 +25,8 @@ def load_plugins(plugin_dirs: str = None) -> None: plugin_dir = plugin_dir.strip() if plugin_dir: plugin_dirs.append(plugin_dir) + if not isinstance(plugin_dirs, list): + plugin_dirs = [plugin_dirs] for plugin_dir in plugin_dirs: if plugin_dir in loaded_dirs: continue From f6be1da8799cf82ad40c87dd8f54bc3fc8531434 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Mon, 18 Aug 2025 15:04:31 +0800 Subject: [PATCH 4/5] apply suggestions from gemini --- trinity/cli/launcher.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index b5c256699b..e36a68fd4a 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -187,10 +187,8 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): config_path, ) envs = os.environ.copy() - if "PLUGIN_DIRS" in envs: - envs["PLUGIN_DIRS"] = os.pathsep.join([plugin_dir, envs["PLUGIN_DIRS"]]) - else: - envs["PLUGIN_DIRS"] = plugin_dir or "" + all_plugin_dirs = [d for d in (plugin_dir, envs.get("PLUGIN_DIRS")) if d] + envs["PLUGIN_DIRS"] = os.pathsep.join(all_plugin_dirs) if dlc: from trinity.utils.dlc_utils import setup_ray_cluster From a43465e810ddd96646d5a670b7194757dc591536 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Mon, 18 Aug 2025 15:12:23 +0800 Subject: [PATCH 5/5] doc fix --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f9c891e22a..68b17b5dcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ dependencies = [ "jsonlines", "sortedcontainers", "word2number", - "transformers<4.54.0", + "transformers<4.54.0", # TODO: remove when https://github.com/vllm-project/vllm-ascend/issues/2046 is fixed ] [project.scripts]