diff --git a/pyproject.toml b/pyproject.toml
index 0a01b05f9d..68b17b5dcc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,7 @@ dependencies = [
     "jsonlines",
     "sortedcontainers",
     "word2number",
+    "transformers<4.54.0",  # TODO: remove when https://github.com/vllm-project/vllm-ascend/issues/2046 is fixed
 ]
 
 [project.scripts]
diff --git a/tests/utils/plugin_test.py b/tests/utils/plugin_test.py
index e35e6eab0d..1e91a72689 100644
--- a/tests/utils/plugin_test.py
+++ b/tests/utils/plugin_test.py
@@ -1,22 +1,37 @@
+import os
 import unittest
 from pathlib import Path
 
 import ray
 
+from tests.tools import TensorBoardParser, get_template_config
+from trinity.common.config import Config
 from trinity.common.workflows import WORKFLOWS
+from trinity.utils.monitor import MONITOR
 from trinity.utils.plugin_loader import load_plugins
 
 
-@ray.remote
 class PluginActor:
+    def __init__(self, config: Config):
+        self.config = config
+        self.monitor = MONITOR.get("my_monitor")(
+            project=self.config.project,
+            group=self.config.group,
+            name=self.config.name,
+            role=self.config.explorer.name,
+            config=config,
+        )
+
     def run(self):
         my_plugin_cls = WORKFLOWS.get("my_workflow")
+        self.monitor.log({"rollout": 2}, step=1, commit=True)
         return my_plugin_cls(task=None, model=None).run()
 
 
 class TestPluginLoader(unittest.TestCase):
     def test_load_plugins(self):
         ray.init(ignore_reinit_error=True)
+        config = get_template_config()
         my_plugin_cls = WORKFLOWS.get("my_workflow")
         self.assertIsNone(my_plugin_cls)
         load_plugins(Path(__file__).resolve().parent / "plugins")
@@ -27,8 +42,15 @@ def test_load_plugins(self):
         res = my_plugin.run()
         self.assertEqual(res[0], "Hello world")
         self.assertEqual(res[1], "Hi")
-        remote_plugin = PluginActor.remote()
+
+        # Remote Actor test
+        remote_plugin = ray.remote(PluginActor).remote(config)
         remote_res = ray.get(remote_plugin.run.remote())
         self.assertEqual(remote_res[0], "Hello world")
         self.assertEqual(remote_res[1], "Hi")
+
+        # test custom monitor
+        parser = TensorBoardParser(os.path.join(config.monitor.cache_dir, "tensorboard"))
+        rollout_cnt = parser.metric_values("rollout")
+        self.assertEqual(rollout_cnt, [2])
         ray.shutdown(_exiting_interpreter=True)
diff --git a/tests/utils/plugins/my_monitor.py b/tests/utils/plugins/my_monitor.py
new file mode 100644
index 0000000000..41a6015ced
--- /dev/null
+++ b/tests/utils/plugins/my_monitor.py
@@ -0,0 +1,30 @@
+import os
+
+from torch.utils.tensorboard import SummaryWriter
+
+from trinity.common.config import Config
+from trinity.utils.log import get_logger
+from trinity.utils.monitor import MONITOR, Monitor
+
+
+@MONITOR.register_module("my_monitor")
+class MyMonitor(Monitor):
+    def __init__(
+        self, project: str, group: str, name: str, role: str, config: Config = None
+    ) -> None:
+        self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard", role)
+        os.makedirs(self.tensorboard_dir, exist_ok=True)
+        self.logger = SummaryWriter(self.tensorboard_dir)
+        self.console_logger = get_logger(__name__)
+
+    def log_table(self, table_name: str, experiences_table, step: int):
+        pass
+
+    def log(self, data: dict, step: int, commit: bool = False) -> None:
+        """Log metrics."""
+        for key in data:
+            self.logger.add_scalar(key, data[key], step)
+        self.console_logger.info(f"Step {step}: {data}")
+
+    def close(self) -> None:
+        self.logger.close()
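
Note: the test above imports `TensorBoardParser` from `tests.tools`, which is not part of this diff. For readers following along, a minimal sketch of a helper with the semantics the test relies on could look like the following (a hypothetical reimplementation, not the actual `tests.tools` code). It walks subdirectories because `MyMonitor` writes to `<cache_dir>/tensorboard/<role>`:

```python
import os

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator


class TensorBoardParserSketch:
    """Hypothetical stand-in for tests.tools.TensorBoardParser."""

    def __init__(self, root_dir: str):
        self.root_dir = root_dir

    def metric_values(self, metric: str) -> list:
        """Collect every logged value of `metric` from event files under root_dir."""
        values = []
        for dirpath, _, filenames in os.walk(self.root_dir):
            # TensorBoard event files are named events.out.tfevents.*
            if not any(f.startswith("events.out.tfevents") for f in filenames):
                continue
            acc = EventAccumulator(dirpath)
            acc.Reload()
            if metric in acc.Tags().get("scalars", []):
                values.extend(event.value for event in acc.Scalars(metric))
        return values
```
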
diff --git a/trinity/__init__.py b/trinity/__init__.py
index 63f1db4fdc..95ddbfd44b 100644
--- a/trinity/__init__.py
+++ b/trinity/__init__.py
@@ -2,3 +2,7 @@
 """Trinity-RFT (Reinforcement Fine-Tuning)"""
 
 __version__ = "0.2.1.dev0"
+
+from trinity.utils.plugin_loader import load_plugins
+
+load_plugins()
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 50bcf5b8c0..e36a68fd4a 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -18,7 +18,6 @@
 from trinity.explorer.explorer import Explorer
 from trinity.trainer.trainer import Trainer
 from trinity.utils.log import get_logger
-from trinity.utils.plugin_loader import load_plugins
 
 logger = get_logger(__name__)
 
@@ -159,7 +158,6 @@ def both(config: Config) -> None:
 
 
 def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
-    load_plugins(plugin_dir)
     config = load_config(config_path)
     config.check_and_update()
     pprint(config)
@@ -188,16 +186,21 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
             f"{data_processor_config.data_processor_url}/{DataProcessorPipelineType.EXPERIENCE.value}",
             config_path,
         )
+    envs = os.environ.copy()
+    all_plugin_dirs = [d for d in (plugin_dir, envs.get("PLUGIN_DIRS")) if d]
+    envs["PLUGIN_DIRS"] = os.pathsep.join(all_plugin_dirs)
     if dlc:
         from trinity.utils.dlc_utils import setup_ray_cluster
 
-        setup_ray_cluster(namespace=config.ray_namespace)
+        setup_ray_cluster(namespace=config.ray_namespace, envs=envs)
     else:
         from trinity.utils.dlc_utils import is_running
 
         if not is_running:
             raise RuntimeError("Ray is not running, please start it by `ray start --head`.")
-        ray.init(namespace=config.ray_namespace, ignore_reinit_error=True)
+        ray.init(
+            namespace=config.ray_namespace, ignore_reinit_error=True, runtime_env={"env_vars": envs}
+        )
     try:
         if config.mode == "explore":
            explore(config)
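
The launcher change above is plain environment plumbing: a CLI-supplied `--plugin-dir` is merged with any `PLUGIN_DIRS` already exported in the shell, and the joined value rides into Ray via `runtime_env`. A standalone sketch of just the merge step, with made-up paths:

```python
import os

# Hypothetical inputs: one directory from the CLI, one inherited from the shell.
plugin_dir = "/workspace/my_plugins"
envs = {"PLUGIN_DIRS": "/opt/shared_plugins"}

# Same merge as in run(): drop empty entries, then join with the platform
# path separator (":" on POSIX, ";" on Windows).
all_plugin_dirs = [d for d in (plugin_dir, envs.get("PLUGIN_DIRS")) if d]
envs["PLUGIN_DIRS"] = os.pathsep.join(all_plugin_dirs)

print(envs["PLUGIN_DIRS"])  # /workspace/my_plugins:/opt/shared_plugins on POSIX
```
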
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index ba422bf5dd..0cda95c488 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -76,7 +76,7 @@ def __init__(self, config: Config):
         self.logger.info("Finished initializing Explorer.")
         self.collect_experiences = self.config.explorer.collect_experiences
         self.generated_experience_cnt = 0
-        if self.collect_experiences:
+        if self.collect_experiences and self.config.mode != "bench":
             assert (
                 self.experience_buffer is not None
             ), "Experience buffer is required when collect_experiences is True."
diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py
index 4a4172feef..c529eb759c 100644
--- a/trinity/explorer/scheduler.py
+++ b/trinity/explorer/scheduler.py
@@ -1,6 +1,7 @@
 """Scheduler for rollout tasks."""
 
 import asyncio
+import os
 import re
 import time
 import traceback
@@ -49,13 +50,15 @@ def __init__(
         self.runner = self._create_runner()
 
     def _create_runner(self):
+        envs = os.environ.copy()
+        envs.update(self.config.explorer.env_vars)
         return (
             ray.remote(WorkflowRunner)
             .options(
                 namespace=self.namespace,
                 scheduling_strategy="SPREAD",
                 runtime_env={
-                    "env_vars": self.config.explorer.env_vars,
+                    "env_vars": envs,
                 },
             )
             .remote(self.config, self.rollout_model, self.auxiliary_models, self.runner_id)
diff --git a/trinity/utils/dlc_utils.py b/trinity/utils/dlc_utils.py
index 380e7d05b9..5bf61c5f06 100644
--- a/trinity/utils/dlc_utils.py
+++ b/trinity/utils/dlc_utils.py
@@ -64,14 +64,14 @@ def wait_for_ray_worker_nodes(world_size: int) -> None:
             time.sleep(1)
 
 
-def setup_ray_cluster(namespace: str):
+def setup_ray_cluster(namespace: str, envs: dict):
     env_vars = get_dlc_env_vars()
     is_master = env_vars["RANK"] == 0
 
     if is_running():
         # reuse existing ray cluster
         if is_master:
-            ray.init(namespace=namespace, ignore_reinit_error=True)
+            ray.init(namespace=namespace, ignore_reinit_error=True, runtime_env={"env_vars": envs})
     else:
         if is_master:
             cmd = f"ray start --head --port={env_vars['MASTER_PORT']} --node-ip-address={env_vars['MASTER_ADDR']}"
@@ -91,6 +91,7 @@ def setup_ray_cluster(namespace: str):
             address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}",
             namespace=namespace,
             ignore_reinit_error=True,
+            runtime_env={"env_vars": envs},
         )
         if is_master:
             # master wait for worker nodes to join
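
The scheduler now copies the driver's environment before overlaying the user-configured `explorer.env_vars`, and `setup_ray_cluster` forwards the same mapping when (re)initializing Ray, so variables such as `PLUGIN_DIRS` propagate into `WorkflowRunner` actors instead of being dropped. A minimal sketch of the pattern (the `EnvProbe` actor is invented for illustration):

```python
import os

import ray


@ray.remote
class EnvProbe:
    """Toy actor that reports an environment variable as seen by the worker."""

    def get(self, key: str):
        return os.environ.get(key)


ray.init(ignore_reinit_error=True)

# Start from the driver's environment, then overlay task-specific variables,
# mirroring Scheduler._create_runner. {"MY_FLAG": "1"} stands in for
# config.explorer.env_vars.
envs = os.environ.copy()
envs.update({"MY_FLAG": "1"})
probe = EnvProbe.options(runtime_env={"env_vars": envs}).remote()
print(ray.get(probe.get.remote("MY_FLAG")))  # "1"
```
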
diff --git a/trinity/utils/plugin_loader.py b/trinity/utils/plugin_loader.py
index 3a4e935d7d..eb617e496e 100644
--- a/trinity/utils/plugin_loader.py
+++ b/trinity/utils/plugin_loader.py
@@ -5,32 +5,46 @@
 import shutil
 import sys
 from pathlib import Path
+from typing import List, Union
 
 from trinity.utils.log import get_logger
 
 logger = get_logger(__name__)
 
+loaded_dirs = set()
 
-def load_plugins(plugin_dir: str) -> None:
+
+def load_plugins(plugin_dirs: Union[str, List[str]] = None) -> None:
     """
     Load plugin modules from a directory.
     """
-    if plugin_dir is None:
-        plugin_dir = Path(__file__).parent.parent / "plugins"
-    if not os.path.exists(plugin_dir):
-        logger.error(f"--plugin-dir [{plugin_dir}] does not exist.")
-        return None
-    if not os.path.isdir(plugin_dir):
-        logger.error(f"--plugin-dir [{plugin_dir}] is not a directory.")
-        return None
-
-    logger.info(f"Loading plugin modules from [{plugin_dir}]...")
-    for file in Path(plugin_dir).glob("*.py"):
-        if file.name.startswith("__"):
+    global loaded_dirs
+    if plugin_dirs is None:
+        plugin_dirs = [Path(__file__).parent.parent / "plugins"]
+        for plugin_dir in os.environ.get("PLUGIN_DIRS", "").split(os.pathsep):
+            plugin_dir = plugin_dir.strip()
+            if plugin_dir:
+                plugin_dirs.append(plugin_dir)
+    if not isinstance(plugin_dirs, list):
+        plugin_dirs = [plugin_dirs]
+    for plugin_dir in plugin_dirs:
+        if plugin_dir in loaded_dirs:
+            continue
+        loaded_dirs.add(plugin_dir)
+        if not os.path.exists(plugin_dir):
+            logger.error(f"--plugin-dir [{plugin_dir}] does not exist.")
             continue
-        logger.info(f"Loading plugin modules from [{file}]...")
-        # load modules from file
-        load_from_file(os.path.join(plugin_dir, file))
+        if not os.path.isdir(plugin_dir):
+            logger.error(f"--plugin-dir [{plugin_dir}] is not a directory.")
+            continue
+
+        logger.info(f"Loading plugin modules from [{plugin_dir}]...")
+        for file in Path(plugin_dir).glob("*.py"):
+            if file.name.startswith("__"):
+                continue
+            logger.info(f"Loading plugin modules from [{file}]...")
+            # load modules from file
+            load_from_file(os.path.join(plugin_dir, file))
 
 
 def load_from_file(file_path: str):
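
After the rework, `load_plugins` accepts a single directory, a list of directories, or nothing at all; with no argument it scans the built-in `trinity/plugins` directory plus every `os.pathsep`-separated entry of the `PLUGIN_DIRS` environment variable, and the module-level `loaded_dirs` set makes repeated calls idempotent. A usage sketch with hypothetical paths:

```python
from trinity.utils.plugin_loader import load_plugins

# A single directory (str) works, as does a list of directories.
load_plugins("/workspace/my_plugins")
load_plugins(["/workspace/my_plugins", "/opt/shared_plugins"])  # first entry skipped: already loaded

# No argument: scan trinity/plugins plus every entry in PLUGIN_DIRS,
# which is how the `import trinity` hook in trinity/__init__.py runs it.
load_plugins()
```
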