Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 87 additions & 25 deletions tests/utils/plugin_test.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,118 @@
import os
import shutil
import unittest
from pathlib import Path
from typing import Type

import ray

from tests.tools import TensorBoardParser, get_template_config
from tests.tools import TensorBoardParser, get_checkpoint_path, get_template_config
from trinity.common.config import Config
from trinity.common.workflows import WORKFLOWS
from trinity.common.constants import PLUGIN_DIRS_ENV_VAR
from trinity.common.workflows import WORKFLOWS, Workflow
from trinity.utils.monitor import MONITOR
from trinity.utils.plugin_loader import load_plugins


class PluginActor:
    """Test actor exercising plugin loading, a custom monitor, and custom workflows.

    Each capability can be toggled off so tests can verify failure modes
    (e.g. an actor dying because a plugin class is not registered).
    """

    def __init__(
        self,
        config: Config,
        enable_load_plugins: bool = True,
        enable_monitor: bool = True,
        enable_workflow: bool = True,
    ):
        """Initialize the actor.

        Args:
            config: Trinity configuration; supplies monitor project/group/name.
            enable_load_plugins: When True, load plugin modules at startup.
            enable_monitor: When True, instantiate the custom ``my_monitor``;
                otherwise ``self.monitor`` is None.
            enable_workflow: When True, assert ``my_workflow`` is registered.
        """
        if enable_load_plugins:
            load_plugins()
        self.config = config
        if enable_monitor:
            # MONITOR.get raises/returns None when "my_monitor" was not
            # registered by a plugin, making the actor die on construction.
            self.monitor = MONITOR.get("my_monitor")(
                project=self.config.project,
                group=self.config.group,
                name=self.config.name,
                role=self.config.explorer.name,
                config=config,
            )
        else:
            self.monitor = None
        if enable_workflow:
            workflow = WORKFLOWS.get("my_workflow")
            assert workflow is not None, "Workflow 'my_workflow' not found in registry"

    def run(self, workflow_cls: Type[Workflow] = None):
        """Instantiate ``workflow_cls`` and run it, logging a metric first if a monitor exists.

        NOTE: the original wrote ``workflow_cls=Type[Workflow]`` — a default
        value where a type annotation was intended; fixed to an annotation here
        (callers always pass the class explicitly, so behavior is unchanged).
        """
        if self.monitor:
            self.monitor.log({"rollout": 2}, step=1, commit=True)
        return workflow_cls(task=None, model=None).run()


class TestPluginLoader(unittest.TestCase):
    """Tests for plugin loading: in-process, inside remote Ray actors, and
    by passing a custom class directly to an actor with plugins disabled."""

    def test_load_plugins_local(self):
        """Plugins loaded from the env-var directory register 'my_workflow' locally."""
        my_workflow_cls = WORKFLOWS.get("my_workflow")
        self.assertIsNone(my_workflow_cls)
        os.environ[PLUGIN_DIRS_ENV_VAR] = str(Path(__file__).resolve().parent / "plugins")
        try:
            load_plugins()
        except KeyError:
            # plugin classes may already be registered by another test in this class
            pass
        my_workflow_cls = WORKFLOWS.get("my_workflow")
        self.assertIsNotNone(my_workflow_cls)
        my_plugin = my_workflow_cls(task=None, model=None, auxiliary_models=None)
        self.assertTrue(my_plugin.__module__.startswith("trinity.plugins"))
        res = my_plugin.run()
        self.assertEqual(res[0], "Hello world")
        self.assertEqual(res[1], "Hi")

    # Remote Actor test
    def test_load_plugins_remote(self):
        """Remote actors only see plugin classes if they load plugins themselves."""
        os.environ[PLUGIN_DIRS_ENV_VAR] = str(Path(__file__).resolve().parent / "plugins")
        try:
            load_plugins()
        except KeyError:
            # plugin classes may already be registered by a previous test
            pass
        config = self.config
        ray.init(
            ignore_reinit_error=True,
            runtime_env={
                "env_vars": {PLUGIN_DIRS_ENV_VAR: str(Path(__file__).resolve().parent / "plugins")}
            },
        )
        my_workflow_cls = WORKFLOWS.get("my_workflow")
        # disable plugin and use custom class from registry: the actor's
        # __init__ cannot find "my_monitor"/"my_workflow", so the actor dies
        remote_plugin = ray.remote(PluginActor).remote(config, enable_load_plugins=False)
        remote_plugin.run.remote(my_workflow_cls)
        with self.assertRaises(ray.exceptions.ActorDiedError):
            ray.get(remote_plugin.__ray_ready__.remote())

        # enable plugin
        remote_plugin = ray.remote(PluginActor).remote(config)
        remote_res = ray.get(remote_plugin.run.remote(my_workflow_cls))
        self.assertEqual(remote_res[0], "Hello world")
        self.assertEqual(remote_res[1], "Hi")

        # test custom monitor
        parser = TensorBoardParser(os.path.join(config.monitor.cache_dir, "tensorboard"))
        rollout_cnt = parser.metric_values("rollout")
        self.assertEqual(rollout_cnt, [2])

    def test_passing_custom_class(self):
        """A workflow class obtained locally can be passed to an actor that never loads plugins."""
        # disable plugin and pass custom class directly
        os.environ[PLUGIN_DIRS_ENV_VAR] = str(Path(__file__).resolve().parent / "plugins")
        try:
            load_plugins()
        except KeyError:
            # plugin classes may already be registered by a previous test
            pass
        my_workflow_cls = WORKFLOWS.get("my_workflow")
        remote_plugin = ray.remote(PluginActor).remote(
            self.config, enable_load_plugins=False, enable_monitor=False, enable_workflow=False
        )
        remote_res = ray.get(remote_plugin.run.remote(my_workflow_cls))
        self.assertEqual(remote_res[0], "Hello world")
        ray.shutdown(_exiting_interpreter=True)

    def setUp(self):
        # Fresh config and a clean monitor cache dir before each test.
        self.config = get_template_config()
        self.config.checkpoint_root_dir = get_checkpoint_path()
        self.config.check_and_update()
        shutil.rmtree(self.config.monitor.cache_dir, ignore_errors=True)
4 changes: 0 additions & 4 deletions trinity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,3 @@
"""Trinity-RFT (Reinforcement Fine-Tuning)"""

__version__ = "0.2.1.dev0"

from trinity.utils.plugin_loader import load_plugins

load_plugins()
7 changes: 3 additions & 4 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import ray

from trinity.common.config import Config, load_config
from trinity.common.constants import DataProcessorPipelineType
from trinity.common.constants import PLUGIN_DIRS_ENV_VAR, DataProcessorPipelineType
from trinity.data.utils import (
activate_data_processor,
stop_data_processor,
Expand Down Expand Up @@ -186,9 +186,8 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
f"{data_processor_config.data_processor_url}/{DataProcessorPipelineType.EXPERIENCE.value}",
config_path,
)
envs = os.environ.copy()
all_plugin_dirs = [d for d in (plugin_dir, envs.get("PLUGIN_DIRS")) if d]
envs["PLUGIN_DIRS"] = os.pathsep.join(all_plugin_dirs)

envs = {PLUGIN_DIRS_ENV_VAR: plugin_dir}
if dlc:
from trinity.utils.dlc_utils import setup_ray_cluster

Expand Down
2 changes: 2 additions & 0 deletions trinity/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

ROLLOUT_WEIGHT_SYNC_GROUP_NAME = "rollout_weight_sync"

PLUGIN_DIRS_ENV_VAR = "TRINITY_PLUGIN_DIRS"


# enumerate types

Expand Down
2 changes: 2 additions & 0 deletions trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@
from trinity.manager.synchronizer import Synchronizer
from trinity.utils.log import get_logger
from trinity.utils.monitor import MONITOR, gather_metrics
from trinity.utils.plugin_loader import load_plugins


class Explorer:
"""Responsible for exploring the taskset."""

def __init__(self, config: Config):
self.logger = get_logger(__name__)
load_plugins()
self.cache = CacheManager(config)
explorer_meta = self.cache.load_explorer()
self.explore_step_num = explorer_meta.get("latest_iteration", 0)
Expand Down
5 changes: 1 addition & 4 deletions trinity/explorer/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Scheduler for rollout tasks."""

import asyncio
import os
import re
import time
import traceback
Expand Down Expand Up @@ -50,15 +49,13 @@ def __init__(
self.runner = self._create_runner()

def _create_runner(self):
envs = os.environ.copy()
envs.update(self.config.explorer.env_vars)
return (
ray.remote(WorkflowRunner)
.options(
namespace=self.namespace,
scheduling_strategy="SPREAD",
runtime_env={
"env_vars": envs,
"env_vars": self.config.explorer.env_vars,
},
)
.remote(self.config, self.rollout_model, self.auxiliary_models, self.runner_id)
Expand Down
2 changes: 1 addition & 1 deletion trinity/explorer/workflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
auxiliary_models: Optional[List[InferenceModel]] = None,
runner_id: Optional[int] = None,
) -> None:
self.logger = get_logger(__name__)
self.config = config
self.experience_buffer = get_buffer_writer(
self.config.buffer.explorer_output, # type: ignore
Expand All @@ -52,7 +53,6 @@ def __init__(
"vllm_async",
).get_openai_client()
self.auxiliary_models.append(api_client)
self.logger = get_logger(__name__)
self.workflow_instance = None
self.runner_id = runner_id
self.return_experiences = self.config.explorer.collect_experiences
Expand Down
2 changes: 2 additions & 0 deletions trinity/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from trinity.manager.synchronizer import Synchronizer
from trinity.utils.log import get_logger
from trinity.utils.monitor import MONITOR
from trinity.utils.plugin_loader import load_plugins


class Trainer:
Expand All @@ -28,6 +29,7 @@ class Trainer:
def __init__(self, config: Config) -> None:
self.config = config
self.logger = get_logger(__name__)
load_plugins()
self.synchronizer = Synchronizer.get_actor(config)
self.engine = get_trainer_wrapper(config)
self.last_trainer_sync_step = 0
Expand Down
3 changes: 2 additions & 1 deletion trinity/trainer/verl/dp_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from trinity.algorithm.kl_fn.kl_fn import DummyKLFn
from trinity.algorithm.utils import prefix_metrics
from trinity.common.config import AlgorithmConfig
from trinity.utils.plugin_loader import load_plugins

__all__ = ["DataParallelPPOActor"]

Expand All @@ -48,7 +49,7 @@ def __init__(
):
"""When optimizer is None, it is Reference Policy"""
super().__init__(config, actor_module, actor_optimizer)

load_plugins()
self.policy_loss_fn = None
self.kl_loss_fn = None
self.entropy_loss_fn = None
Expand Down
30 changes: 16 additions & 14 deletions trinity/utils/plugin_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,37 @@
from pathlib import Path
from typing import List, Union

from trinity.common.constants import PLUGIN_DIRS_ENV_VAR
from trinity.utils.log import get_logger

logger = get_logger(__name__)

loaded_dirs = set()

def load_plugins() -> None:
    """
    Load plugin modules from the directories listed in the ``TRINITY_PLUGIN_DIRS``
    environment variable (name held in ``PLUGIN_DIRS_ENV_VAR``), separated by
    ``os.pathsep``. Falls back to the default ``trinity/plugins`` directory when
    the variable is unset or contains no usable entries.
    """
    # Filter out empty/whitespace entries so values like "a::b" or a trailing
    # separator don't produce spurious "does not exist" errors downstream.
    raw_dirs = os.environ.get(PLUGIN_DIRS_ENV_VAR, "")
    plugin_dirs = [d.strip() for d in raw_dirs.split(os.pathsep) if d.strip()]
    if not plugin_dirs:
        plugin_dirs = [str(Path(__file__).parent.parent / "plugins")]

    load_plugin_from_dirs(plugin_dirs)

def load_plugins(plugin_dirs: Union[str, List[str]] = None) -> None:

def load_plugin_from_dirs(plugin_dirs: Union[str, List[str]]) -> None:
"""
Load plugin modules from a directory.
"""
global loaded_dirs
if plugin_dirs is None:
plugin_dirs = [Path(__file__).parent.parent / "plugins"]
for plugin_dir in os.environ.get("PLUGIN_DIRS", "").split(os.pathsep):
plugin_dir = plugin_dir.strip()
if plugin_dir:
plugin_dirs.append(plugin_dir)
if not isinstance(plugin_dirs, list):
plugin_dirs = [plugin_dirs]
plugin_dirs = set(plugin_dirs)
for plugin_dir in plugin_dirs:
if plugin_dir in loaded_dirs:
continue
loaded_dirs.add(plugin_dir)
if not os.path.exists(plugin_dir):
logger.error(f"--plugin-dir [{plugin_dir}] does not exist.")
logger.error(f"plugin-dir [{plugin_dir}] does not exist.")
continue
if not os.path.isdir(plugin_dir):
logger.error(f"--plugin-dir [{plugin_dir}] is not a directory.")
logger.error(f"plugin-dir [{plugin_dir}] is not a directory.")
continue

logger.info(f"Loading plugin modules from [{plugin_dir}]...")
Expand Down