Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies = [
"jsonlines",
"sortedcontainers",
"word2number",
"transformers<4.54.0", # TODO: remove when https://github.com/vllm-project/vllm-ascend/issues/2046 is fixed
]

[project.scripts]
Expand Down
26 changes: 24 additions & 2 deletions tests/utils/plugin_test.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,37 @@
import os
import unittest
from pathlib import Path

import ray

from tests.tools import TensorBoardParser, get_template_config
from trinity.common.config import Config
from trinity.common.workflows import WORKFLOWS
from trinity.utils.monitor import MONITOR
from trinity.utils.plugin_loader import load_plugins


@ray.remote
class PluginActor:
    """Ray actor used to verify that registered plugins are visible inside remote workers."""

    def __init__(self, config: Config):
        self.config = config
        # Look up the custom monitor registered by the plugin loader and
        # instantiate it for this actor's explorer role.
        monitor_cls = MONITOR.get("my_monitor")
        self.monitor = monitor_cls(
            project=self.config.project,
            group=self.config.group,
            name=self.config.name,
            role=self.config.explorer.name,
            config=config,
        )

    def run(self):
        """Log one metric via the plugin monitor and run the plugin workflow."""
        workflow_cls = WORKFLOWS.get("my_workflow")
        self.monitor.log({"rollout": 2}, step=1, commit=True)
        return workflow_cls(task=None, model=None).run()


class TestPluginLoader(unittest.TestCase):
def test_load_plugins(self):
ray.init(ignore_reinit_error=True)
config = get_template_config()
my_plugin_cls = WORKFLOWS.get("my_workflow")
self.assertIsNone(my_plugin_cls)
load_plugins(Path(__file__).resolve().parent / "plugins")
Expand All @@ -27,8 +42,15 @@ def test_load_plugins(self):
res = my_plugin.run()
self.assertEqual(res[0], "Hello world")
self.assertEqual(res[1], "Hi")
remote_plugin = PluginActor.remote()

# Remote Actor test
remote_plugin = ray.remote(PluginActor).remote(config)
remote_res = ray.get(remote_plugin.run.remote())
self.assertEqual(remote_res[0], "Hello world")
self.assertEqual(remote_res[1], "Hi")

# test custom monitor
parser = TensorBoardParser(os.path.join(config.monitor.cache_dir, "tensorboard"))
rollout_cnt = parser.metric_values("rollout")
self.assertEqual(rollout_cnt, [2])
ray.shutdown(_exiting_interpreter=True)
30 changes: 30 additions & 0 deletions tests/utils/plugins/my_monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os

from torch.utils.tensorboard import SummaryWriter

from trinity.common.config import Config
from trinity.utils.log import get_logger
from trinity.utils.monitor import MONITOR, Monitor


@MONITOR.register_module("my_monitor")
class MyMonitor(Monitor):
    """Minimal custom monitor plugin: mirrors scalar metrics to TensorBoard and the console.

    Registered under the name ``"my_monitor"`` so it can be looked up through the
    ``MONITOR`` registry after plugins are loaded.
    """

    def __init__(
        self, project: str, group: str, name: str, role: str, config: Config = None
    ) -> None:
        # NOTE(review): `config` defaults to None but is dereferenced
        # unconditionally below, so callers must always pass a Config.
        # One TensorBoard directory per role keeps event files separate.
        self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard", role)
        os.makedirs(self.tensorboard_dir, exist_ok=True)
        self.logger = SummaryWriter(self.tensorboard_dir)
        self.console_logger = get_logger(__name__)

    def log_table(self, table_name: str, experiences_table, step: int):
        """Tables are not supported by this monitor; intentionally a no-op."""
        pass

    def log(self, data: dict, step: int, commit: bool = False) -> None:
        """Log scalar metrics to TensorBoard and echo the full dict to the console."""
        # items() avoids a second dict lookup per key.
        for key, value in data.items():
            self.logger.add_scalar(key, value, step)
        self.console_logger.info(f"Step {step}: {data}")

    def close(self) -> None:
        """Flush and release the underlying SummaryWriter."""
        self.logger.close()
4 changes: 4 additions & 0 deletions trinity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,7 @@
"""Trinity-RFT (Reinforcement Fine-Tuning)"""

__version__ = "0.2.1.dev0"

from trinity.utils.plugin_loader import load_plugins

load_plugins()
11 changes: 7 additions & 4 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from trinity.explorer.explorer import Explorer
from trinity.trainer.trainer import Trainer
from trinity.utils.log import get_logger
from trinity.utils.plugin_loader import load_plugins

logger = get_logger(__name__)

Expand Down Expand Up @@ -159,7 +158,6 @@ def both(config: Config) -> None:


def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
load_plugins(plugin_dir)
config = load_config(config_path)
config.check_and_update()
pprint(config)
Expand Down Expand Up @@ -188,16 +186,21 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
f"{data_processor_config.data_processor_url}/{DataProcessorPipelineType.EXPERIENCE.value}",
config_path,
)
envs = os.environ.copy()
all_plugin_dirs = [d for d in (plugin_dir, envs.get("PLUGIN_DIRS")) if d]
envs["PLUGIN_DIRS"] = os.pathsep.join(all_plugin_dirs)
if dlc:
from trinity.utils.dlc_utils import setup_ray_cluster

setup_ray_cluster(namespace=config.ray_namespace)
setup_ray_cluster(namespace=config.ray_namespace, envs=envs)
else:
from trinity.utils.dlc_utils import is_running

if not is_running:
raise RuntimeError("Ray is not running, please start it by `ray start --head`.")
ray.init(namespace=config.ray_namespace, ignore_reinit_error=True)
ray.init(
namespace=config.ray_namespace, ignore_reinit_error=True, runtime_env={"env_vars": envs}
)
try:
if config.mode == "explore":
explore(config)
Expand Down
2 changes: 1 addition & 1 deletion trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, config: Config):
self.logger.info("Finished initializing Explorer.")
self.collect_experiences = self.config.explorer.collect_experiences
self.generated_experience_cnt = 0
if self.collect_experiences:
if self.collect_experiences and self.config.mode != "bench":
assert (
self.experience_buffer is not None
), "Experience buffer is required when collect_experiences is True."
Expand Down
5 changes: 4 additions & 1 deletion trinity/explorer/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Scheduler for rollout tasks."""

import asyncio
import os
import re
import time
import traceback
Expand Down Expand Up @@ -49,13 +50,15 @@ def __init__(
self.runner = self._create_runner()

def _create_runner(self):
    """Create a remote WorkflowRunner with the explorer's environment variables applied."""
    # Start from the parent process environment; explorer-specific
    # variables take precedence over inherited ones.
    runner_env = os.environ.copy()
    runner_env.update(self.config.explorer.env_vars)
    runner_options = {
        "namespace": self.namespace,
        "scheduling_strategy": "SPREAD",
        "runtime_env": {"env_vars": runner_env},
    }
    remote_cls = ray.remote(WorkflowRunner).options(**runner_options)
    return remote_cls.remote(
        self.config, self.rollout_model, self.auxiliary_models, self.runner_id
    )
Expand Down
5 changes: 3 additions & 2 deletions trinity/utils/dlc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ def wait_for_ray_worker_nodes(world_size: int) -> None:
time.sleep(1)


def setup_ray_cluster(namespace: str):
def setup_ray_cluster(namespace: str, envs: dict):
env_vars = get_dlc_env_vars()
is_master = env_vars["RANK"] == 0

if is_running():
# reuse existing ray cluster
if is_master:
ray.init(namespace=namespace, ignore_reinit_error=True)
ray.init(namespace=namespace, ignore_reinit_error=True, runtime_env={"env_vars": envs})
else:
if is_master:
cmd = f"ray start --head --port={env_vars['MASTER_PORT']} --node-ip-address={env_vars['MASTER_ADDR']}"
Expand All @@ -91,6 +91,7 @@ def setup_ray_cluster(namespace: str):
address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}",
namespace=namespace,
ignore_reinit_error=True,
runtime_env={"env_vars": envs},
)
if is_master:
# master wait for worker nodes to join
Expand Down
46 changes: 30 additions & 16 deletions trinity/utils/plugin_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,46 @@
import shutil
import sys
from pathlib import Path
from typing import List, Union

from trinity.utils.log import get_logger

logger = get_logger(__name__)

loaded_dirs = set()

def load_plugins(plugin_dir: str) -> None:

def load_plugins(plugin_dirs: Union[str, List[str]] = None) -> None:
    """
    Load plugin modules from one or more directories.

    Args:
        plugin_dirs: A directory path or a list of directory paths. When
            ``None``, the built-in ``trinity/plugins`` directory is used,
            plus any directories listed in the ``PLUGIN_DIRS`` environment
            variable (separated by ``os.pathsep``).

    Directories that were already loaded are skipped silently; directories
    that do not exist or are not directories are skipped with an error log.
    Files whose names start with ``__`` (e.g. ``__init__.py``) are ignored.
    """
    if plugin_dirs is None:
        plugin_dirs = [Path(__file__).parent.parent / "plugins"]
        for plugin_dir in os.environ.get("PLUGIN_DIRS", "").split(os.pathsep):
            plugin_dir = plugin_dir.strip()
            if plugin_dir:
                plugin_dirs.append(plugin_dir)
    if not isinstance(plugin_dirs, list):
        plugin_dirs = [plugin_dirs]
    for plugin_dir in plugin_dirs:
        # Normalize to a resolved string path so the same directory passed
        # as str / Path / relative path is only ever loaded once (the raw
        # values would not compare equal across types).
        dir_key = str(Path(plugin_dir).resolve())
        if dir_key in loaded_dirs:
            continue
        loaded_dirs.add(dir_key)
        if not os.path.exists(plugin_dir):
            logger.error(f"--plugin-dir [{plugin_dir}] does not exist.")
            continue
        if not os.path.isdir(plugin_dir):
            logger.error(f"--plugin-dir [{plugin_dir}] is not a directory.")
            continue

        logger.info(f"Loading plugin modules from [{plugin_dir}]...")
        for file in Path(plugin_dir).glob("*.py"):
            if file.name.startswith("__"):
                continue
            logger.info(f"Loading plugin modules from [{file}]...")
            # `file` already includes `plugin_dir` (Path.glob joins them);
            # re-joining with os.path.join would duplicate the directory
            # component for relative plugin_dir values.
            load_from_file(str(file))


def load_from_file(file_path: str):
Expand Down