diff --git a/tests/utils/plugin_test.py b/tests/utils/plugin_test.py index 1e91a72689..b1a6c7194c 100644 --- a/tests/utils/plugin_test.py +++ b/tests/utils/plugin_test.py @@ -1,51 +1,92 @@ import os +import shutil import unittest from pathlib import Path +from typing import Type import ray -from tests.tools import TensorBoardParser, get_template_config +from tests.tools import TensorBoardParser, get_checkpoint_path, get_template_config from trinity.common.config import Config -from trinity.common.workflows import WORKFLOWS +from trinity.common.constants import PLUGIN_DIRS_ENV_VAR +from trinity.common.workflows import WORKFLOWS, Workflow from trinity.utils.monitor import MONITOR from trinity.utils.plugin_loader import load_plugins class PluginActor: - def __init__(self, config: Config): + def __init__( + self, + config: Config, + enable_load_plugins: bool = True, + enable_monitor: bool = True, + enable_workflow: bool = True, + ): + if enable_load_plugins: + load_plugins() self.config = config - self.monitor = MONITOR.get("my_monitor")( - project=self.config.project, - group=self.config.group, - name=self.config.name, - role=self.config.explorer.name, - config=config, - ) + if enable_monitor: + self.monitor = MONITOR.get("my_monitor")( + project=self.config.project, + group=self.config.group, + name=self.config.name, + role=self.config.explorer.name, + config=config, + ) + else: + self.monitor = None + if enable_workflow: + workflow = WORKFLOWS.get("my_workflow") + assert workflow is not None, "Workflow 'my_workflow' not found in registry" - def run(self): - my_plugin_cls = WORKFLOWS.get("my_workflow") - self.monitor.log({"rollout": 2}, step=1, commit=True) - return my_plugin_cls(task=None, model=None).run() + def run(self, workflow_cls=Type[Workflow]): + if self.monitor: + self.monitor.log({"rollout": 2}, step=1, commit=True) + return workflow_cls(task=None, model=None).run() class TestPluginLoader(unittest.TestCase): - def test_load_plugins(self): - ray.init(ignore_reinit_error=True) - config = get_template_config() - my_plugin_cls = WORKFLOWS.get("my_workflow") - self.assertIsNone(my_plugin_cls) - load_plugins(Path(__file__).resolve().parent / "plugins") - my_plugin_cls = WORKFLOWS.get("my_workflow") - self.assertIsNotNone(my_plugin_cls) - my_plugin = my_plugin_cls(task=None, model=None, auxiliary_models=None) + def test_load_plugins_local(self): + my_workflow_cls = WORKFLOWS.get("my_workflow") + self.assertIsNone(my_workflow_cls) + os.environ[PLUGIN_DIRS_ENV_VAR] = str(Path(__file__).resolve().parent / "plugins") + try: + load_plugins() + except KeyError: + # already registered in next test + pass + my_workflow_cls = WORKFLOWS.get("my_workflow") + self.assertIsNotNone(my_workflow_cls) + my_plugin = my_workflow_cls(task=None, model=None, auxiliary_models=None) self.assertTrue(my_plugin.__module__.startswith("trinity.plugins")) res = my_plugin.run() self.assertEqual(res[0], "Hello world") self.assertEqual(res[1], "Hi") - # Remote Actor test + def test_load_plugins_remote(self): + os.environ[PLUGIN_DIRS_ENV_VAR] = str(Path(__file__).resolve().parent / "plugins") + try: + load_plugins() + except KeyError: + # already registered in previous test + pass + config = self.config + ray.init( + ignore_reinit_error=True, + runtime_env={ + "env_vars": {PLUGIN_DIRS_ENV_VAR: str(Path(__file__).resolve().parent / "plugins")} + }, + ) + my_workflow_cls = WORKFLOWS.get("my_workflow") + # disable plugin and use custom class from registry + remote_plugin = ray.remote(PluginActor).remote(config, enable_load_plugins=False) + remote_plugin.run.remote(my_workflow_cls) + with self.assertRaises(ray.exceptions.ActorDiedError): + ray.get(remote_plugin.__ray_ready__.remote()) + + # enable plugin remote_plugin = ray.remote(PluginActor).remote(config) - remote_res = ray.get(remote_plugin.run.remote()) + remote_res = ray.get(remote_plugin.run.remote(my_workflow_cls)) self.assertEqual(remote_res[0], "Hello world") self.assertEqual(remote_res[1], "Hi") @@ -53,4 +94,25 @@ def test_load_plugins(self): parser = TensorBoardParser(os.path.join(config.monitor.cache_dir, "tensorboard")) rollout_cnt = parser.metric_values("rollout") self.assertEqual(rollout_cnt, [2]) + + def test_passing_custom_class(self): + # disable plugin and pass custom class directly + os.environ[PLUGIN_DIRS_ENV_VAR] = str(Path(__file__).resolve().parent / "plugins") + try: + load_plugins() + except KeyError: + # already registered in previous test + pass + my_workflow_cls = WORKFLOWS.get("my_workflow") + remote_plugin = ray.remote(PluginActor).remote( + self.config, enable_load_plugins=False, enable_monitor=False, enable_workflow=False + ) + remote_res = ray.get(remote_plugin.run.remote(my_workflow_cls)) + self.assertEqual(remote_res[0], "Hello world") ray.shutdown(_exiting_interpreter=True) + + def setUp(self): + self.config = get_template_config() + self.config.checkpoint_root_dir = get_checkpoint_path() + self.config.check_and_update() + shutil.rmtree(self.config.monitor.cache_dir, ignore_errors=True) diff --git a/trinity/__init__.py b/trinity/__init__.py index 95ddbfd44b..63f1db4fdc 100644 --- a/trinity/__init__.py +++ b/trinity/__init__.py @@ -2,7 +2,3 @@ """Trinity-RFT (Reinforcement Fine-Tuning)""" __version__ = "0.2.1.dev0" - -from trinity.utils.plugin_loader import load_plugins - -load_plugins() diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index e36a68fd4a..e61a59d037 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -9,7 +9,7 @@ import ray from trinity.common.config import Config, load_config -from trinity.common.constants import DataProcessorPipelineType +from trinity.common.constants import PLUGIN_DIRS_ENV_VAR, DataProcessorPipelineType from trinity.data.utils import ( activate_data_processor, stop_data_processor, @@ -186,9 +186,8 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): f"{data_processor_config.data_processor_url}/{DataProcessorPipelineType.EXPERIENCE.value}", config_path, ) - envs = os.environ.copy() - all_plugin_dirs = [d for d in (plugin_dir, envs.get("PLUGIN_DIRS")) if d] - envs["PLUGIN_DIRS"] = os.pathsep.join(all_plugin_dirs) + + envs = {PLUGIN_DIRS_ENV_VAR: plugin_dir} if dlc: from trinity.utils.dlc_utils import setup_ray_cluster diff --git a/trinity/common/constants.py b/trinity/common/constants.py index 392f2dc553..e4428ed16b 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -13,6 +13,8 @@ ROLLOUT_WEIGHT_SYNC_GROUP_NAME = "rollout_weight_sync" +PLUGIN_DIRS_ENV_VAR = "TRINITY_PLUGIN_DIRS" + # enumerate types diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 6533b2ebcd..fd1a3ab47e 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -28,6 +28,7 @@ from trinity.manager.synchronizer import Synchronizer from trinity.utils.log import get_logger from trinity.utils.monitor import MONITOR, gather_metrics +from trinity.utils.plugin_loader import load_plugins class Explorer: @@ -35,6 +36,7 @@ class Explorer: def __init__(self, config: Config): self.logger = get_logger(__name__) + load_plugins() self.cache = CacheManager(config) explorer_meta = self.cache.load_explorer() self.explore_step_num = explorer_meta.get("latest_iteration", 0) diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index c529eb759c..4a4172feef 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -1,7 +1,6 @@ """Scheduler for rollout tasks.""" import asyncio -import os import re import time import traceback @@ -50,15 +49,13 @@ def __init__( self.runner = self._create_runner() def _create_runner(self): - envs = os.environ.copy() - envs.update(self.config.explorer.env_vars) return ( ray.remote(WorkflowRunner) .options( namespace=self.namespace, scheduling_strategy="SPREAD", runtime_env={ - "env_vars": envs, + "env_vars": self.config.explorer.env_vars, }, ) .remote(self.config, self.rollout_model, self.auxiliary_models, self.runner_id) diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index daf5a03f8f..159e3cc4d6 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -33,6 +33,7 @@ def __init__( auxiliary_models: Optional[List[InferenceModel]] = None, runner_id: Optional[int] = None, ) -> None: + self.logger = get_logger(__name__) self.config = config self.experience_buffer = get_buffer_writer( self.config.buffer.explorer_output, # type: ignore @@ -52,7 +53,6 @@ def __init__( "vllm_async", ).get_openai_client() self.auxiliary_models.append(api_client) - self.logger = get_logger(__name__) self.workflow_instance = None self.runner_id = runner_id self.return_experiences = self.config.explorer.collect_experiences diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 9a25f8ca57..0326c30d7a 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -20,6 +20,7 @@ from trinity.manager.synchronizer import Synchronizer from trinity.utils.log import get_logger from trinity.utils.monitor import MONITOR +from trinity.utils.plugin_loader import load_plugins class Trainer: @@ -28,6 +29,7 @@ class Trainer: def __init__(self, config: Config) -> None: self.config = config self.logger = get_logger(__name__) + load_plugins() self.synchronizer = Synchronizer.get_actor(config) self.engine = get_trainer_wrapper(config) self.last_trainer_sync_step = 0 diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index e6d5a91458..4eafcb95a9 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -35,6 +35,7 @@ from trinity.algorithm.kl_fn.kl_fn import DummyKLFn from trinity.algorithm.utils import prefix_metrics from trinity.common.config import AlgorithmConfig +from trinity.utils.plugin_loader import load_plugins __all__ = ["DataParallelPPOActor"] @@ -48,7 +49,7 @@ def __init__( ): """When optimizer is None, it is Reference Policy""" super().__init__(config, actor_module, actor_optimizer) - + load_plugins() self.policy_loss_fn = None self.kl_loss_fn = None self.entropy_loss_fn = None diff --git a/trinity/utils/plugin_loader.py b/trinity/utils/plugin_loader.py index eb617e496e..d67b43c332 100644 --- a/trinity/utils/plugin_loader.py +++ b/trinity/utils/plugin_loader.py @@ -7,35 +7,37 @@ from pathlib import Path from typing import List, Union +from trinity.common.constants import PLUGIN_DIRS_ENV_VAR from trinity.utils.log import get_logger logger = get_logger(__name__) -loaded_dirs = set() +def load_plugins() -> None: + """ + Load plugin modules from the default plugin directory or directories specified in the environment variable. + If the environment variable `PLUGIN_DIRS_ENV_VAR` is not set, it defaults to `trinity/plugins`. + """ + plugin_dirs = os.environ.get(PLUGIN_DIRS_ENV_VAR, "").split(os.pathsep) + if not plugin_dirs or plugin_dirs == [""]: + plugin_dirs = [str(Path(__file__).parent.parent / "plugins")] + + load_plugin_from_dirs(plugin_dirs) -def load_plugins(plugin_dirs: Union[str, List[str]] = None) -> None: + +def load_plugin_from_dirs(plugin_dirs: Union[str, List[str]]) -> None: """ Load plugin modules from a directory. """ - global loaded_dirs - if plugin_dirs is None: - plugin_dirs = [Path(__file__).parent.parent / "plugins"] - for plugin_dir in os.environ.get("PLUGIN_DIRS", "").split(os.pathsep): - plugin_dir = plugin_dir.strip() - if plugin_dir: - plugin_dirs.append(plugin_dir) if not isinstance(plugin_dirs, list): plugin_dirs = [plugin_dirs] + plugin_dirs = set(plugin_dirs) for plugin_dir in plugin_dirs: - if plugin_dir in loaded_dirs: - continue - loaded_dirs.add(plugin_dir) if not os.path.exists(plugin_dir): - logger.error(f"--plugin-dir [{plugin_dir}] does not exist.") + logger.error(f"plugin-dir [{plugin_dir}] does not exist.") continue if not os.path.isdir(plugin_dir): - logger.error(f"--plugin-dir [{plugin_dir}] is not a directory.") + logger.error(f"plugin-dir [{plugin_dir}] is not a directory.") continue logger.info(f"Loading plugin modules from [{plugin_dir}]...")