From 8ecb77d82f78652ad9059db799e6f1523f8c415b Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 25 Jun 2025 16:03:31 +0800 Subject: [PATCH 1/9] support multiple explorer --- .../source/tutorial/trinity_configs.md | 43 ++++++++- tests/buffer/file_test.py | 2 +- tests/buffer/queue_test.py | 3 +- tests/buffer/sql_test.py | 3 + tests/tools.py | 2 +- tests/trainer/trainer_test.py | 91 ++++++++++++++++++- trinity/buffer/buffer_writer.py | 16 +++- trinity/buffer/queue.py | 39 +++++++- trinity/buffer/ray_wrapper.py | 35 ++++++- trinity/buffer/reader/queue_reader.py | 10 +- trinity/buffer/writer/file_writer.py | 13 ++- trinity/buffer/writer/queue_writer.py | 17 ++-- trinity/buffer/writer/sql_writer.py | 15 ++- trinity/common/config.py | 8 +- trinity/data/core/dataset.py | 2 +- trinity/explorer/explorer.py | 3 +- 16 files changed, 254 insertions(+), 48 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index f6a6d8c780..6b9abdeb25 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -68,6 +68,7 @@ checkpoint_root_dir: /PATH/TO/CHECKPOINT - `explore`: Only launches the explorer. - `bench`: Used for benchmarking. - `checkpoint_root_dir`: Root directory where all checkpoints and logs will be saved. Checkpoints for this experiment will be stored in `///`. +- `ray_namespace`: Namespace for the modules launched in the current experiment. If not specified, it will be set to `/`. --- @@ -166,6 +167,9 @@ buffer: eval_tasksets: ... + explorer_output: + ... + trainer_input: experience_buffer: ... @@ -225,9 +229,9 @@ The configuration for each task dataset is defined as follows: - `queue`: The dataset is stored in a queue. The queue is a simple FIFO queue that stores the task dataset. *Do not use this storage type for task dataset unless you know what you are doing.* - `sql`: The dataset is stored in a SQL database. *This type is unstable and will be optimized in the future versions.* - `path`: The path to the task dataset. - - For `file` storage type, the path is the path to the directory that contains the task dataset files. + - For `file` storage type, the path points to the directory that contains the task dataset files. - For `queue` storage type, the path is optional. You can back up the data in the queue by specifying a sqlite database path here. - - For `sql` storage type, the path is the path to the sqlite database file. + - For `sql` storage type, the path points to the sqlite database file. - `subset_name`: The subset name of the task dataset. Default is `None`. - `split`: The split of the task dataset. Default is `train`. - `format`: Defines keys for prompts and responses in the dataset. @@ -240,6 +244,37 @@ The configuration for each task dataset is defined as follows: - `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters. +### Explorer Output + +In [`explore` mode](#global-configuration), since there is no trainer, users can configure an experience buffer via `buffer.explorer_input`, rather than using `buffer.trainer_input`, which will be introduced in the next section. + +> For `both` and `train` modes, users should use `buffer.trainer_input` instead of `buffer.explorer_output`. + +```yaml +buffer: + ... + explorer_output: + name: countdown_buffer + storage_type: queue + path: sqlite:///countdown_buffer.db + wrap_in_ray: True + ray_namespace: Trinity-RFT/example +``` + +- `name`: The name of the experience buffer. 
This name will be used as the Ray actor's name, so it must be unique. +- `storage_type`: The storage type for the experience buffer. + - `queue`: Experience data is stored in a queue. This storage type is recommended for most use cases. + - `sql`: Experience data is stored in a SQL database. If your database only supports local access (e.g., SQLite), set `wrap_in_ray` to `True` to wrap the database in a Ray actor, enabling remote access from other nodes. + - `file`: Experience data is stored in a JSON file. This storage type should be used only for debugging purposes in `explore` mode. + - For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data. + - For `file` storage type, the path points to the directory containing the dataset files. + - For `sql` storage type, the path points to the SQLite database file. + +- `wrap_in_ray`: Whether to wrap the experience buffer in a Ray actor. Only take effect when `storage_type` is `sql` or `file`. The `queue` storage always uses a Ray actor. + +- `ray_namespace`: The Ray namespace of the experience buffer. If you want to connect to an existing experience buffer launched by another experiment, set this value to the `ray_namespace` of the target experiment and provide the corresponding buffer’s `name`. Otherwise, this field can be omitted. + + ### Trainer Input Defines the experience buffer and optional SFT warm-up dataset. @@ -264,7 +299,7 @@ buffer: sft_warmup_steps: 0 ``` -- `experience_buffer`: Experience replay buffer used by the trainer. +- `experience_buffer`: Experience buffer used by the trainer, which is logically equivalent to `buffer.explorer_output`. - `sft_warmup_dataset`: Optional dataset used for pre-training (SFT warmup). - `sft_warmup_steps`: Number of steps to use SFT warm-up before RL begins. @@ -301,6 +336,7 @@ Controls how model weights are synchronized between trainer and explorer. synchronizer: sync_method: 'nccl' sync_interval: 10 + sync_offset: 0 sync_timeout: 1200 ``` @@ -308,6 +344,7 @@ synchronizer: - `nccl`: Uses NCCL for fast synchronization. Supported for `both` mode. - `checkpoint`: Loads latest model from disk. Supported for `train`, `explore`, or `bench` mode. - `sync_interval`: Interval (in steps) of model weight synchronization between trainer and explorer. +- `sync_offset`: Offset (in steps) of model weight synchronization between trainer and explorer. The explorer can run `sync_offset` steps before the trainer starts training. - `sync_timeout`: Timeout duration for synchronization. 
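+
+For example, when the explorer and trainer run as separate processes (fully asynchronous mode), the `checkpoint` method is the supported option, and `sync_offset` lets the explorer start rolling out before the trainer produces its first update. A sketch with illustrative values:
+
+```yaml
+synchronizer:
+  sync_method: 'checkpoint'  # 'nccl' is only supported in `both` mode
+  sync_interval: 10          # synchronize model weights every 10 steps
+  sync_offset: 5             # the explorer may run 5 steps before the trainer starts training
+  sync_timeout: 1200
+```
+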
--- diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index e53669a850..2882dd8e0f 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -47,7 +47,7 @@ def test_file_buffer(self): # test writer writer = JSONWriter(meta, None) writer.write(data) - writer.finish() + writer.release() # test reader meta.path = self.temp_output_path diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index 03e96e4291..23271c6158 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -30,6 +30,7 @@ def test_queue_buffer(self): ) writer = QueueWriter(meta, config) reader = QueueReader(meta, config) + self.assertEqual(writer.acquire(), 1) exps = [ Experience( tokens=torch.tensor([float(j) for j in range(i + 1)]), @@ -59,7 +60,7 @@ def test_queue_buffer(self): ) exps = reader.read(batch_size=put_batch_size * 2) self.assertEqual(len(exps), put_batch_size * 2) - writer.finish() + self.assertEqual(writer.release(), 0) self.assertRaises(StopIteration, reader.read) with open(BUFFER_FILE_PATH, "r") as f: self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2) diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 56305be671..e40a91b4c7 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -42,6 +42,7 @@ def test_create_sql_buffer(self) -> None: ) for i in range(1, put_batch_size + 1) ] + self.assertEqual(sql_writer.acquire(), 1) for _ in range(total_num // put_batch_size): sql_writer.write(exps) for _ in range(total_num // read_batch_size): @@ -65,3 +66,5 @@ def test_create_sql_buffer(self) -> None: self.assertEqual(len(exps), put_batch_size * 2) db_wrapper = ray.get_actor("sql-test_buffer") self.assertIsNotNone(db_wrapper) + self.assertEqual(sql_writer.release(), 0) + self.assertRaises(StopIteration, sql_reader.read) diff --git a/tests/tools.py b/tests/tools.py index 209b5eb1c2..5f7cd1785d 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -42,7 +42,7 @@ def get_checkpoint_path() -> str: def get_unittest_dataset_config( dataset_name: str = "countdown", split: str = "train" ) -> StorageConfig: - """Countdown sample dataset for 8 steps""" + """Countdown sample dataset for 4 steps""" if dataset_name == "countdown" or dataset_name == "copy_countdown": return StorageConfig( name=dataset_name, diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 32d19e9190..2a782812be 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -1,7 +1,10 @@ """Tests for trainer.""" import os import shutil +import time +import unittest from abc import abstractmethod +from copy import deepcopy from datetime import datetime import ray @@ -14,8 +17,9 @@ get_template_config, get_unittest_dataset_config, ) -from trinity.cli.launcher import bench, both, train -from trinity.common.constants import SyncMethod +from trinity.cli.launcher import bench, both, explore, train +from trinity.common.config import Config, StorageConfig +from trinity.common.constants import StorageType, SyncMethod class BaseTrainerCase(RayUnittestBase): @@ -149,7 +153,6 @@ def test_trainer(self): response_metrics = parser.metric_list("response_length") self.assertTrue(len(response_metrics) > 0) self.assertEqual(parser.metric_max_step(response_metrics[0]), 4) - ray.timeline(filename="timeline.json") ray.shutdown(_exiting_interpreter=True) # check checkpoint from trinity.common.models.utils import get_checkpoint_dir_with_step_num @@ -163,7 +166,8 @@ def test_trainer(self): def tearDown(self): # remove dir only 
when the test passed - shutil.rmtree(self.config.checkpoint_job_dir) + pass + # shutil.rmtree(self.config.checkpoint_job_dir) class TestTrainerGSM8K(BaseTrainerCase): @@ -262,3 +266,82 @@ def test_trainer(self): def tearDown(self): # remove dir only when the test passed shutil.rmtree(self.config.checkpoint_job_dir) + + +def run_trainer(config: Config) -> None: + ray.init(namespace=config.ray_namespace) + train(config) + + +def run_explorer(config: Config) -> None: + ray.init(namespace=config.ray_namespace) + explore(config) + + +class TestFullyAsyncMode(unittest.TestCase): + def test_fully_async_mode(self): + trainer_name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}" + config = get_template_config() + config.project = "unittest" + config.checkpoint_root_dir = get_checkpoint_path() + config.buffer.total_epochs = 1 + config.buffer.batch_size = 4 + config.cluster.gpu_per_node = 2 + config.cluster.node_num = 1 + config.model.model_path = get_model_path() + config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") + config.buffer.trainer_input.experience_buffer = StorageConfig( + name="exp_buffer", + storage_type=StorageType.QUEUE, + wrap_in_ray=True, + ) + config.synchronizer.sync_method = SyncMethod.CHECKPOINT + config.synchronizer.sync_interval = 4 + config.monitor.monitor_type = "tensorboard" + trainer_config = deepcopy(config) + trainer_config.mode = "train" + trainer_config.name = trainer_name + trainer_config.check_and_update() + + explorer1_config = deepcopy(config) + explorer1_config.mode = "explore" + config.cluster.gpu_per_node = 1 + config.cluster.node_num = 1 + explorer1_config.explorer.rollout_model.engine_num = 1 + explorer1_config.explorer.rollout_model.tensor_parallel_size = 1 + explorer1_config.explorer.runner_num = 4 + explorer1_config.name = f"explorer1-{datetime.now().strftime('%Y%m%d%H%M%S')}" + explorer1_config.buffer.explorer_output = StorageConfig( + name="exp_buffer", + storage_type=StorageType.QUEUE, + wrap_in_ray=True, + ray_namespace=f"unittest/{trainer_config.name}", + ) + explorer2_config = deepcopy(explorer1_config) + explorer1_config.check_and_update() + + import multiprocessing + + multiprocessing.set_start_method("spawn") + + trainer_process = multiprocessing.Process(target=run_trainer, args=(trainer_config,)) + trainer_process.start() + + ray.init(ignore_reinit_error=True) + while True: + try: + ray.get_actor("queue-exp_buffer", namespace=trainer_config.ray_namespace) + break + except ValueError: + print("waiting for trainer to start.") + time.sleep(5) + + explorer_process_1 = multiprocessing.Process(target=run_explorer, args=(explorer1_config,)) + explorer_process_1.start() + + explorer_process_2 = multiprocessing.Process(target=run_explorer, args=(explorer2_config,)) + explorer_process_2.start() + + explorer_process_1.join() + explorer_process_2.join() + trainer_process.join() diff --git a/trinity/buffer/buffer_writer.py b/trinity/buffer/buffer_writer.py index ac245f50b6..13079ffb76 100644 --- a/trinity/buffer/buffer_writer.py +++ b/trinity/buffer/buffer_writer.py @@ -11,5 +11,17 @@ def write(self, data: List) -> None: """Write to buffer.""" @abstractmethod - def finish(self) -> None: - """Finish writing.""" + def acquire(self) -> int: + """Acquire the buffer writer. + + Returns: + `int`: The reference count of the buffer after acquiring. + """ + + @abstractmethod + def release(self) -> int: + """Release the buffer writer. After release, the buffer writer can not be used again. 
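+
+        Buffer writers are reference counted (see `acquire`): the underlying
+        storage is only finalized once the last writer has released it.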
+ + Returns: + `int`: The reference count of the buffer after releasing. + """ diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py index a3db72ef90..3e6ba4f329 100644 --- a/trinity/buffer/queue.py +++ b/trinity/buffer/queue.py @@ -3,6 +3,8 @@ from copy import deepcopy from typing import List +import ray + from trinity.buffer.writer.file_writer import JSONWriter from trinity.buffer.writer.sql_writer import SQLWriter from trinity.common.config import BufferConfig, StorageConfig @@ -44,6 +46,19 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: st_config.storage_type = StorageType.FILE self.writer = JSONWriter(st_config, self.config) self.logger.warning(f"Save experiences in {st_config.path}.") + self.ref_count = 0 + + async def acquire(self) -> int: + self.ref_count += 1 + return self.ref_count + + async def release(self) -> int: + """Release the queue.""" + self.ref_count -= 1 + if self.ref_count <= 0: + await self.queue.put(self.FINISH_MESSAGE) + self.writer.release() + return self.ref_count def length(self) -> int: """The length of the queue.""" @@ -55,10 +70,6 @@ async def put_batch(self, exp_list: List) -> None: if self.writer is not None: self.writer.write(exp_list) - async def finish(self) -> None: - """Stop the queue.""" - await self.queue.put(self.FINISH_MESSAGE) - async def get_batch(self, batch_size: int) -> List: """Get batch of experience.""" batch = [] @@ -70,3 +81,23 @@ async def get_batch(self, batch_size: int) -> List: if len(batch) >= batch_size: break return batch + + @classmethod + def get_actor(cls, storage_config: StorageConfig, config: BufferConfig): + """Get the queue actor.""" + if storage_config.ray_namespace: + queue_actor = ray.get_actor( + f"queue-{storage_config.name}", + namespace=storage_config.ray_namespace, + ) + else: + queue_actor = ( + ray.remote(cls) + .options( + name=f"queue-{storage_config.name}", + namespace=ray.get_runtime_context().namespace, + get_if_exists=True, + ) + .remote(storage_config, config) + ) + return queue_actor diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py index 71e9102999..4e3e4baf50 100644 --- a/trinity/buffer/ray_wrapper.py +++ b/trinity/buffer/ray_wrapper.py @@ -47,10 +47,16 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: self.batch_size = config.read_batch_size self.max_retry_times = config.max_retry_times self.max_retry_interval = config.max_retry_interval + self.ref_count = 0 + self.stopped = False @classmethod def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): if storage_config.wrap_in_ray: + if storage_config.ray_namespace: + return ray.get_actor( + f"sql-{storage_config.name}", namespace=storage_config.ray_namespace + ) return ( ray.remote(cls) .options( @@ -71,6 +77,9 @@ def write(self, data: list) -> None: def read( self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None ) -> List: + if self.stopped: + raise StopIteration() + if strategy is None: strategy = ReadStrategy.LFU @@ -114,6 +123,16 @@ def read( self.logger.info(f"first response_text = {exp_list[0].response_text}") return exp_list + def acquire(self) -> int: + self.ref_count += 1 + return self.ref_count + + def release(self) -> int: + self.ref_count -= 1 + if self.ref_count <= 0: + self.stopped = True + return self.ref_count + class _Encoder(json.JSONEncoder): def default(self, o): @@ -147,10 +166,15 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: os.makedirs(path_dir, 
exist_ok=True) self.file = open(storage_config.path, "a", encoding="utf-8") self.encoder = _Encoder(ensure_ascii=False) + self.ref_count = 0 @classmethod def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): if storage_config.wrap_in_ray: + if storage_config.ray_namespace: + return ray.get_actor( + f"json-{storage_config.name}", namespace=storage_config.ray_namespace + ) return ( ray.remote(cls) .options( @@ -174,5 +198,12 @@ def read(self) -> List: "read() is not implemented for FileWrapper, please use QUEUE instead" ) - def finish(self) -> None: - self.file.close() + def acquire(self) -> int: + self.ref_count += 1 + return self.ref_count + + def release(self) -> int: + self.ref_count -= 1 + if self.ref_count <= 0: + self.file.close() + return self.ref_count diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index 271c2931e2..6591ddde4a 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -19,15 +19,7 @@ class QueueReader(BufferReader): def __init__(self, storage_config: StorageConfig, config: BufferConfig): assert storage_config.storage_type == StorageType.QUEUE self.read_batch_size = config.read_batch_size - self.queue = ( - ray.remote(QueueActor) - .options( - name=f"queue-{storage_config.name}", - namespace=ray.get_runtime_context().namespace, - get_if_exists=True, - ) - .remote(storage_config, config) - ) + self.queue = QueueActor.get_actor(storage_config, config) def read( self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None diff --git a/trinity/buffer/writer/file_writer.py b/trinity/buffer/writer/file_writer.py index 0fc4929ca5..16ec96d0a9 100644 --- a/trinity/buffer/writer/file_writer.py +++ b/trinity/buffer/writer/file_writer.py @@ -20,8 +20,15 @@ def write(self, data: List) -> None: else: self.writer.write(data) - def finish(self): + def acquire(self) -> int: if self.wrap_in_ray: - ray.get(self.writer.finish.remote()) + return ray.get(self.writer.acquire()) else: - self.writer.finish() + return 0 + + def release(self) -> int: + if self.wrap_in_ray: + return ray.get(self.writer.release.remote()) + else: + self.writer.release() + return 0 diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py index ec2316a0ec..7b12fab4c1 100644 --- a/trinity/buffer/writer/queue_writer.py +++ b/trinity/buffer/writer/queue_writer.py @@ -18,18 +18,13 @@ class QueueWriter(BufferWriter): def __init__(self, meta: StorageConfig, config: BufferConfig): assert meta.storage_type == StorageType.QUEUE self.config = config - self.queue = ( - ray.remote(QueueActor) - .options( - name=f"queue-{meta.name}", - namespace=ray.get_runtime_context().namespace, - get_if_exists=True, - ) - .remote(meta, config) - ) + self.queue = QueueActor.get_actor(meta, config) def write(self, data: List) -> None: ray.get(self.queue.put_batch.remote(data)) - def finish(self): - ray.get(self.queue.finish.remote()) + def acquire(self) -> int: + return ray.get(self.queue.acquire.remote()) + + def release(self) -> int: + return ray.get(self.queue.release.remote()) diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py index 8864dc9b82..95344d4447 100644 --- a/trinity/buffer/writer/sql_writer.py +++ b/trinity/buffer/writer/sql_writer.py @@ -23,6 +23,15 @@ def write(self, data: list) -> None: else: self.db_wrapper.write(data) - def finish(self) -> None: - # TODO: implement this - pass + def acquire(self) -> int: + if self.wrap_in_ray: + 
return ray.get(self.db_wrapper.acquire.remote()) + else: + return 0 + + def release(self) -> int: + if self.wrap_in_ray: + return ray.get(self.db_wrapper.release.remote()) + else: + self.db_wrapper.release() + return 0 diff --git a/trinity/common/config.py b/trinity/common/config.py index 5d60cf8c4c..fb1534d721 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -79,12 +79,16 @@ class StorageConfig: format: FormatConfig = field(default_factory=FormatConfig) index: int = 0 - # used for StorageType.SQL + # used for StorageType.SQL/FILE wrap_in_ray: bool = True # used for StorageType.QUEUE capacity: int = 10000 + # used in Fully Async Mode, + # the explorer started later can connect to an existing buffer in different namespace + ray_namespace: Optional[str] = None + # used for rollout tasks default_workflow_type: Optional[str] = None default_reward_fn_type: Optional[str] = None @@ -582,7 +586,7 @@ def check_and_update(self) -> None: # noqa: C901 # set namespace if self.ray_namespace is None or len(self.ray_namespace) == 0: - self.ray_namespace = f"{self.project}-{self.name}" + self.ray_namespace = f"{self.project}/{self.name}" # check algorithm self._check_algorithm() diff --git a/trinity/data/core/dataset.py b/trinity/data/core/dataset.py index 93be832cc7..6b6d126f9b 100644 --- a/trinity/data/core/dataset.py +++ b/trinity/data/core/dataset.py @@ -84,7 +84,7 @@ def write_to_buffer( buffer_config = self.buffer_config output_buffer = get_buffer_writer(output_storage_config, buffer_config) output_buffer.write(self.data.to_list()) - output_buffer.finish() + output_buffer.release() self.data = Dataset.from_list([]) def to_parquet(self, path: str): diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 31ade5f84b..0ec7c7411b 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -47,6 +47,7 @@ def __init__(self, config: Config): self.config.buffer.explorer_output, # type: ignore self.config.buffer, ) + self.experience_buffer.acquire() self.config.buffer.explorer_input.taskset.index = explorer_meta.get("latest_task_index", 0) self.taskset = get_buffer_reader( self.config.buffer.explorer_input.taskset, self.config.buffer @@ -202,7 +203,7 @@ def explore_step(self) -> bool: ) self.status = RunningStatus.STOPPED self.wait_for_workflow_done() - self.experience_buffer.finish() + self.experience_buffer.release() return False self.runner_pool.run_tasks(tasks) self.explore_step_num += 1 From 3a62dac7241b6b810aa7f59731254eae3d1567bd Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 25 Jun 2025 17:35:45 +0800 Subject: [PATCH 2/9] fix checkpoint sync --- tests/tools.py | 2 +- tests/trainer/trainer_test.py | 61 ++++++++++++++++++++--- trinity/cli/launcher.py | 15 +++--- trinity/common/config.py | 4 ++ trinity/common/models/vllm_async_model.py | 2 + trinity/common/models/vllm_model.py | 2 + trinity/common/models/vllm_worker.py | 7 ++- trinity/common/verl_config.py | 3 ++ trinity/explorer/explorer.py | 6 +-- trinity/manager/manager.py | 4 +- trinity/trainer/trainer.py | 11 ++-- trinity/trainer/verl/fsdp_workers.py | 8 +-- trinity/trainer/verl_trainer.py | 3 +- trinity/utils/monitor.py | 2 +- 14 files changed, 90 insertions(+), 40 deletions(-) diff --git a/tests/tools.py b/tests/tools.py index 5f7cd1785d..0b4ffd5750 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -42,7 +42,7 @@ def get_checkpoint_path() -> str: def get_unittest_dataset_config( dataset_name: str = "countdown", split: str = "train" ) -> StorageConfig: - """Countdown sample dataset 
for 4 steps""" + """Countdown dataset with 16 samples.""" if dataset_name == "countdown" or dataset_name == "copy_countdown": return StorageConfig( name=dataset_name, diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 2a782812be..21a1b9f1b2 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -20,6 +20,8 @@ from trinity.cli.launcher import bench, both, explore, train from trinity.common.config import Config, StorageConfig from trinity.common.constants import StorageType, SyncMethod +from trinity.common.models.utils import get_checkpoint_dir_with_step_num +from trinity.manager.manager import CacheManager class BaseTrainerCase(RayUnittestBase): @@ -166,8 +168,7 @@ def test_trainer(self): def tearDown(self): # remove dir only when the test passed - pass - # shutil.rmtree(self.config.checkpoint_job_dir) + shutil.rmtree(self.config.checkpoint_job_dir) class TestTrainerGSM8K(BaseTrainerCase): @@ -280,9 +281,9 @@ def run_explorer(config: Config) -> None: class TestFullyAsyncMode(unittest.TestCase): def test_fully_async_mode(self): - trainer_name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}" config = get_template_config() config.project = "unittest" + config.name = f"fully_async_{datetime.now().strftime('%Y%m%d%H%M%S')}" config.checkpoint_root_dir = get_checkpoint_path() config.buffer.total_epochs = 1 config.buffer.batch_size = 4 @@ -300,22 +301,20 @@ def test_fully_async_mode(self): config.monitor.monitor_type = "tensorboard" trainer_config = deepcopy(config) trainer_config.mode = "train" - trainer_config.name = trainer_name trainer_config.check_and_update() explorer1_config = deepcopy(config) explorer1_config.mode = "explore" + explorer1_config.explorer.name = "explorer1" config.cluster.gpu_per_node = 1 config.cluster.node_num = 1 explorer1_config.explorer.rollout_model.engine_num = 1 explorer1_config.explorer.rollout_model.tensor_parallel_size = 1 explorer1_config.explorer.runner_num = 4 - explorer1_config.name = f"explorer1-{datetime.now().strftime('%Y%m%d%H%M%S')}" explorer1_config.buffer.explorer_output = StorageConfig( name="exp_buffer", storage_type=StorageType.QUEUE, wrap_in_ray=True, - ray_namespace=f"unittest/{trainer_config.name}", ) explorer2_config = deepcopy(explorer1_config) explorer1_config.check_and_update() @@ -339,9 +338,57 @@ def test_fully_async_mode(self): explorer_process_1 = multiprocessing.Process(target=run_explorer, args=(explorer1_config,)) explorer_process_1.start() + time.sleep(20) + explorer2_config.explorer.name = "explorer2" + explorer2_config.check_and_update() explorer_process_2 = multiprocessing.Process(target=run_explorer, args=(explorer2_config,)) explorer_process_2.start() explorer_process_1.join() explorer_process_2.join() - trainer_process.join() + + # wait for trainer process to finish. 
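+        # A bounded join keeps the test from hanging indefinitely if the trainer fails
+        # to exit; 200 seconds is assumed to be enough for it to finish training here.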
+ trainer_process.join(timeout=200) + + # check the tensorboard + parser = TensorBoardParser( + os.path.join(trainer_config.monitor.cache_dir, "tensorboard", "trainer") + ) + actor_metrics = parser.metric_list("actor") + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8) + parser = TensorBoardParser( + os.path.join(explorer1_config.monitor.cache_dir, "tensorboard", "explorer1") + ) + rollout_metrics = parser.metric_list("rollout") + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) + parser = TensorBoardParser( + os.path.join(explorer2_config.monitor.cache_dir, "tensorboard", "explorer2") + ) + rollout_metrics = parser.metric_list("rollout") + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) + # check the checkpoint + explorer1_cache = CacheManager(explorer1_config) + cache = explorer1_cache.load_explorer() + self.assertEqual(cache["latest_iteration"], 4) + explorer2_cache = CacheManager(explorer2_config) + cache = explorer2_cache.load_explorer() + self.assertEqual(cache["latest_iteration"], 4) + self.assertIsNotNone( + get_checkpoint_dir_with_step_num( + checkpoint_root_path=explorer1_config.checkpoint_job_dir, + trainer_type="verl", + step_num=8, + ) + ) + self.assertIsNotNone( + get_checkpoint_dir_with_step_num( + checkpoint_root_path=explorer2_config.checkpoint_job_dir, + trainer_type="verl", + step_num=8, + ) + ) + ray.shutdown() + + def tearDown(self): + checkpoint_path = get_checkpoint_path() + shutil.rmtree(os.path.join(checkpoint_path, "unittest")) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index e4123820de..15b669d61c 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -9,7 +9,6 @@ import ray from trinity.common.config import Config, DataPipelineConfig, load_config -from trinity.common.constants import EXPLORER_NAME, TRAINER_NAME from trinity.explorer.explorer import Explorer from trinity.trainer.trainer import Trainer from trinity.utils.log import get_logger @@ -23,7 +22,7 @@ def bench(config: Config) -> None: explorer = ( ray.remote(Explorer) .options( - name=EXPLORER_NAME, + name=config.explorer.name, namespace=ray.get_runtime_context().namespace, ) .remote(config) @@ -44,7 +43,7 @@ def explore(config: Config) -> None: explorer = ( ray.remote(Explorer) .options( - name=EXPLORER_NAME, + name=config.explorer.name, namespace=ray.get_runtime_context().namespace, ) .remote(config) @@ -64,7 +63,7 @@ def train(config: Config) -> None: trainer = ( ray.remote(Trainer) .options( - name=TRAINER_NAME, + name=config.trainer.name, namespace=ray.get_runtime_context().namespace, ) .remote(config) @@ -92,7 +91,7 @@ def both(config: Config) -> None: explorer = ( ray.remote(Explorer) .options( - name=EXPLORER_NAME, + name=config.explorer.name, namespace=namespace, ) .remote(config) @@ -100,7 +99,7 @@ def both(config: Config) -> None: trainer = ( ray.remote(Trainer) .options( - name=TRAINER_NAME, + name=config.trainer.name, namespace=namespace, ) .remote(config) @@ -127,7 +126,7 @@ def both(config: Config) -> None: ) ready = ray.get(ready_ref[0]) - if ready == TRAINER_NAME: + if ready == config.trainer.name: logger.info( "===========================================================\n" "> Launcher detected that the `Trainer` process has finished.\n" @@ -135,7 +134,7 @@ def both(config: Config) -> None: "===========================================================" ) ray.wait(wait_ref, timeout=5) - elif ready == EXPLORER_NAME: + elif ready == config.explorer.name: logger.info( 
"============================================================\n" "> Launcher detected that the `Explorer` process has finished.\n" diff --git a/trinity/common/config.py b/trinity/common/config.py index fb1534d721..6f7ca59271 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -7,6 +7,8 @@ from omegaconf import OmegaConf from trinity.common.constants import ( + EXPLORER_NAME, + TRAINER_NAME, PromptType, ReadStrategy, StorageType, @@ -283,6 +285,7 @@ class BufferConfig: class ExplorerConfig: """Config for explorer.""" + name: str = EXPLORER_NAME # for workflow runner # number of workflow runners. # For sync engine (vllm), it should be equal to `engine_num`. @@ -304,6 +307,7 @@ class ExplorerConfig: @dataclass class TrainerConfig: + name: str = TRAINER_NAME trainer_type: str = "verl" save_interval: int = 0 enable_preview: bool = True # enable rollout preview in wandb diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py index 79c0cfae01..101d04764b 100644 --- a/trinity/common/models/vllm_async_model.py +++ b/trinity/common/models/vllm_async_model.py @@ -282,6 +282,7 @@ async def init_process_group( rank_offset: int, world_size: int, group_name: str, + explorer_name: str, backend: str = "nccl", timeout: int = 1200, update_with_checkpoint: bool = True, @@ -300,6 +301,7 @@ async def init_process_group( update_with_checkpoint, state_dict_meta, ray.get_runtime_context().namespace, + explorer_name, ), ) diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 3efe88b000..1e154da3a2 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -96,6 +96,7 @@ def init_process_group( rank_offset: int, world_size: int, group_name: str, + explorer_name: str, backend: str = "nccl", timeout: int = 1200, update_with_checkpoint: bool = True, @@ -114,6 +115,7 @@ def init_process_group( update_with_checkpoint, state_dict_meta, ray.get_runtime_context().namespace, + explorer_name, ), ) diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 2a156b8a2a..235ac1b013 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -4,7 +4,6 @@ import torch import torch.distributed -from trinity.common.constants import EXPLORER_NAME from trinity.utils.distributed import init_process_group, is_ipv6_address from trinity.utils.log import get_logger @@ -23,6 +22,7 @@ def init_process_group( timeout: int = 1200, update_with_checkpoint: bool = True, state_dict_meta: list = None, + explorer_name: str = None, namespace: str = None, ): """Init torch process group for model weights update""" @@ -53,6 +53,7 @@ def init_process_group( group_name=group_name, ) logger.info("vLLM init_process_group finished.") + self._explorer_name = explorer_name self._namespace = namespace self._explorer_actor = None @@ -63,7 +64,9 @@ def update_weight(self): """Broadcast weight to all vllm workers from source rank 0 (actor model)""" assert self._state_dict_meta is not None if self._explorer_actor is None: - self._explorer_actor = ray.get_actor(name=EXPLORER_NAME, namespace=self._namespace) + self._explorer_actor = ray.get_actor( + name=self._explorer_name, namespace=self._namespace + ) for name, dtype_str, shape in self._state_dict_meta: if self._weight_update_rank == 0: weight = ray.get(self._explorer_actor.get_weight.remote(name)) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 1ec0653503..0d2d3bf8c1 100644 
--- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -5,6 +5,7 @@ from omegaconf import OmegaConf from trinity.common.config import BufferConfig, Config, SynchronizerConfig +from trinity.common.constants import EXPLORER_NAME from trinity.utils.log import get_logger logger = get_logger(__name__) @@ -119,6 +120,7 @@ class ActorRolloutRef: ref: Ref = field(default_factory=Ref) rollout: Rollout = field(default_factory=Rollout) synchronizer: Optional[SynchronizerConfig] = None + explorer_name: str = EXPLORER_NAME @dataclass @@ -298,6 +300,7 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 self.synchronizer = config.synchronizer self.actor_rollout_ref.synchronizer = config.synchronizer + self.actor_rollout_ref.explorer_name = config.explorer.name # Actor / Critic config self.actor_rollout_ref.model.path = config.model.model_path diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 0ec7c7411b..a35afa4c87 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -15,7 +15,6 @@ from trinity.buffer.buffer import get_buffer_reader from trinity.common.config import Config from trinity.common.constants import ( - EXPLORER_NAME, ROLLOUT_WEIGHT_SYNC_GROUP_NAME, RunningStatus, SyncMethod, @@ -56,7 +55,7 @@ def __init__(self, config: Config): self.monitor = MONITOR.get(self.config.monitor.monitor_type)( project=self.config.project, name=self.config.name, - role=EXPLORER_NAME, + role=self.config.explorer.name, config=config, ) self.batch_size = config.buffer.batch_size @@ -101,6 +100,7 @@ async def setup_weight_sync_group( + base_offset, world_size=world_size, group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME, + explorer_name=self.config.explorer.name, timeout=self.config.synchronizer.sync_timeout, update_with_checkpoint=self.use_checkpoint_weights_update, state_dict_meta=state_dict_meta, @@ -185,7 +185,7 @@ async def explore(self) -> str: self.logger.error(f"Error in Explorer: {e}") break self.logger.info("--------------------\n> Explorer finished.\n--------------------") - return EXPLORER_NAME + return self.config.explorer.name def explore_step(self) -> bool: algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num + 1) diff --git a/trinity/manager/manager.py b/trinity/manager/manager.py index baaf1242c3..4af6f28685 100644 --- a/trinity/manager/manager.py +++ b/trinity/manager/manager.py @@ -14,8 +14,8 @@ class CacheManager: def __init__(self, config: Config, check_config: bool = False): self.cache_dir = config.monitor.cache_dir # type: ignore - self.explorer_meta_path = os.path.join(self.cache_dir, "explorer_meta.json") # type: ignore - self.trainer_meta_path = os.path.join(self.cache_dir, "trainer_meta.json") # type: ignore + self.explorer_meta_path = os.path.join(self.cache_dir, f"{config.explorer.name}_meta.json") # type: ignore + self.trainer_meta_path = os.path.join(self.cache_dir, f"{config.trainer.name}_meta.json") # type: ignore if check_config: self._check_config_consistency(config) diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 216c916c69..91f681e47f 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -10,12 +10,7 @@ import ray from trinity.common.config import Config -from trinity.common.constants import ( - EXPLORER_NAME, - TRAINER_NAME, - RunningStatus, - SyncMethod, -) +from trinity.common.constants import RunningStatus, SyncMethod from trinity.utils.log import get_logger @@ -45,7 +40,7 @@ def train(self) -> str: 
self.logger.error(f"Error in Trainer: {e}") break self.logger.info("--------------------\n> Trainer finished.\n--------------------") - return TRAINER_NAME + return self.config.trainer.name def train_step(self) -> bool: """Train one step. @@ -63,7 +58,7 @@ def sync_weight(self) -> None: """Sync the model weight.""" if self.config.synchronizer.sync_method == SyncMethod.NCCL: if self.explorer_ref is None: - self.explorer_ref = ray.get_actor(EXPLORER_NAME) + self.explorer_ref = ray.get_actor(self.config.explorer.name) explorer_status = ray.get(self.explorer_ref.running_status.remote()) if explorer_status == RunningStatus.STOPPED: self.logger.warning("Explorer has already stopped. Skipping sync weight.") diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index cbc88902a0..5e76375315 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -71,11 +71,7 @@ from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager from trinity.common.config import AlgorithmConfig -from trinity.common.constants import ( - EXPLORER_NAME, - ROLLOUT_WEIGHT_SYNC_GROUP_NAME, - SyncMethod, -) +from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod from trinity.utils.distributed import init_process_group, is_ipv6_address logger = logging.getLogger(__file__) @@ -577,7 +573,7 @@ def setup_weight_sync_group(self): master_address, master_port = self.get_availale_master_addr_port() world_size = self.config.synchronizer.explorer_world_size + 1 print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).") - explorer = ray.get_actor(EXPLORER_NAME) + explorer = ray.get_actor(self.config.explorer_name) setup_ref = explorer.setup_weight_sync_group.remote( master_address, master_port, self.state_dict_meta ) diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index d041bea128..7c789a98d2 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -36,7 +36,6 @@ from trinity.algorithm.algorithm_manager import AlgorithmManager from trinity.algorithm.utils import prefix_metrics from trinity.common.config import Config -from trinity.common.constants import TRAINER_NAME from trinity.common.experience import Experiences from trinity.trainer.trainer import TrainEngineWrapper from trinity.utils.monitor import MONITOR @@ -150,7 +149,7 @@ def __init__( self.logger = MONITOR.get(global_config.monitor.monitor_type)( project=config.trainer.project_name, name=config.trainer.experiment_name, - role=TRAINER_NAME, + role=global_config.trainer.name, config=global_config, ) self.reset_experiences_example_table() diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index f12a854335..965fb7e4df 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -69,7 +69,7 @@ def calculate_metrics( @MONITOR.register_module("tensorboard") class TensorboardMonitor(Monitor): def __init__(self, project: str, name: str, role: str, config: Config = None) -> None: - self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard") + self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard", role) os.makedirs(self.tensorboard_dir, exist_ok=True) self.logger = SummaryWriter(self.tensorboard_dir) self.console_logger = get_logger(__name__) From 1b97d446e415e381f3d1a4b4d109f81dfe1b7a40 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 25 Jun 2025 17:42:38 +0800 Subject: [PATCH 3/9] clean config --- 
.../source/tutorial/trinity_configs.md | 11 +++++---- tests/trainer/trainer_test.py | 2 +- trinity/buffer/queue.py | 23 +++++++------------ trinity/buffer/ray_wrapper.py | 8 ------- trinity/common/config.py | 4 ---- 5 files changed, 15 insertions(+), 33 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 6b9abdeb25..c66af57c0d 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -223,7 +223,7 @@ buffer: The configuration for each task dataset is defined as follows: -- `name`: Name of the dataset. Name must be unique. +- `name`: Name of the dataset. This name will be used as the Ray actor's name, so it must be unique. - `storage_type`: How the dataset is stored. Options: `file`, `queue`, `sql`. - `file`: The dataset is stored in `jsonl`/`parquet` files. The data file organization is required to meet the huggingface standard. *We recommand using this storage type for most cases.* - `queue`: The dataset is stored in a queue. The queue is a simple FIFO queue that stores the task dataset. *Do not use this storage type for task dataset unless you know what you are doing.* @@ -258,7 +258,6 @@ buffer: storage_type: queue path: sqlite:///countdown_buffer.db wrap_in_ray: True - ray_namespace: Trinity-RFT/example ``` - `name`: The name of the experience buffer. This name will be used as the Ray actor's name, so it must be unique. @@ -269,11 +268,8 @@ buffer: - For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data. - For `file` storage type, the path points to the directory containing the dataset files. - For `sql` storage type, the path points to the SQLite database file. - - `wrap_in_ray`: Whether to wrap the experience buffer in a Ray actor. Only take effect when `storage_type` is `sql` or `file`. The `queue` storage always uses a Ray actor. -- `ray_namespace`: The Ray namespace of the experience buffer. If you want to connect to an existing experience buffer launched by another experiment, set this value to the `ray_namespace` of the target experiment and provide the corresponding buffer’s `name`. Otherwise, this field can be omitted. - ### Trainer Input @@ -311,6 +307,7 @@ Controls the rollout models and workflow execution. ```yaml explorer: + name: explorer runner_num: 32 rollout_model: engine_type: vllm_async @@ -321,11 +318,13 @@ explorer: tensor_parallel_size: 1 ``` +- `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique. - `runner_num`: Number of parallel workflow runners. - `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`. - `rollout_model.engine_num`: Number of inference engines. - `rollout_model.tensor_parallel_size`: Degree of tensor parallelism. - `auxiliary_models`: Additional models used for custom workflows. + --- ## Synchronizer Configuration @@ -355,12 +354,14 @@ Specifies the backend and behavior of the trainer. ```yaml trainer: + name: trainer trainer_type: 'verl' save_interval: 100 trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml' trainer_config: null ``` +- `name`: Name of the trainer. This name will be used as the Ray actor's name, so it must be unique. - `trainer_type`: Trainer backend implementation. Currently only supports `verl`. - `save_interval`: Frequency (in steps) at which to save model checkpoints. 
- `trainer_config_path`: The path to the trainer configuration file. diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 21a1b9f1b2..14e365c6c0 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -297,7 +297,7 @@ def test_fully_async_mode(self): wrap_in_ray=True, ) config.synchronizer.sync_method = SyncMethod.CHECKPOINT - config.synchronizer.sync_interval = 4 + config.synchronizer.sync_interval = 8 config.monitor.monitor_type = "tensorboard" trainer_config = deepcopy(config) trainer_config.mode = "train" diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py index 3e6ba4f329..9cdd99a592 100644 --- a/trinity/buffer/queue.py +++ b/trinity/buffer/queue.py @@ -85,19 +85,12 @@ async def get_batch(self, batch_size: int) -> List: @classmethod def get_actor(cls, storage_config: StorageConfig, config: BufferConfig): """Get the queue actor.""" - if storage_config.ray_namespace: - queue_actor = ray.get_actor( - f"queue-{storage_config.name}", - namespace=storage_config.ray_namespace, + return ( + ray.remote(cls) + .options( + name=f"queue-{storage_config.name}", + namespace=ray.get_runtime_context().namespace, + get_if_exists=True, ) - else: - queue_actor = ( - ray.remote(cls) - .options( - name=f"queue-{storage_config.name}", - namespace=ray.get_runtime_context().namespace, - get_if_exists=True, - ) - .remote(storage_config, config) - ) - return queue_actor + .remote(storage_config, config) + ) diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py index 4e3e4baf50..ba736b02bc 100644 --- a/trinity/buffer/ray_wrapper.py +++ b/trinity/buffer/ray_wrapper.py @@ -53,10 +53,6 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: @classmethod def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): if storage_config.wrap_in_ray: - if storage_config.ray_namespace: - return ray.get_actor( - f"sql-{storage_config.name}", namespace=storage_config.ray_namespace - ) return ( ray.remote(cls) .options( @@ -171,10 +167,6 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: @classmethod def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): if storage_config.wrap_in_ray: - if storage_config.ray_namespace: - return ray.get_actor( - f"json-{storage_config.name}", namespace=storage_config.ray_namespace - ) return ( ray.remote(cls) .options( diff --git a/trinity/common/config.py b/trinity/common/config.py index 6f7ca59271..e6130395b2 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -87,10 +87,6 @@ class StorageConfig: # used for StorageType.QUEUE capacity: int = 10000 - # used in Fully Async Mode, - # the explorer started later can connect to an existing buffer in different namespace - ray_namespace: Optional[str] = None - # used for rollout tasks default_workflow_type: Optional[str] = None default_reward_fn_type: Optional[str] = None From a889c103607bc795871ffb41de5ae75fbc393b32 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 25 Jun 2025 17:54:15 +0800 Subject: [PATCH 4/9] fix checkpoint sync --- trinity/common/models/vllm_async_model.py | 2 +- trinity/common/models/vllm_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py index 101d04764b..0806bc9c7d 100644 --- a/trinity/common/models/vllm_async_model.py +++ b/trinity/common/models/vllm_async_model.py @@ -300,8 +300,8 @@ async def init_process_group( 
timeout, update_with_checkpoint, state_dict_meta, - ray.get_runtime_context().namespace, explorer_name, + ray.get_runtime_context().namespace, ), ) diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 1e154da3a2..59211f198a 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -114,8 +114,8 @@ def init_process_group( timeout, update_with_checkpoint, state_dict_meta, - ray.get_runtime_context().namespace, explorer_name, + ray.get_runtime_context().namespace, ), ) From c173cd8c8a072e3d828971a60174502f5a6780dc Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 25 Jun 2025 18:22:50 +0800 Subject: [PATCH 5/9] fix tests --- tests/explorer/runner_pool_test.py | 2 +- tests/trainer/trainer_test.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py index 52f961bda4..8e9247344a 100644 --- a/tests/explorer/runner_pool_test.py +++ b/tests/explorer/runner_pool_test.py @@ -253,6 +253,6 @@ def test_runner_pool_with_auxiliary_models(self): st = time.time() status = pool.get_next_unorder() et = time.time() - self.assertTrue(et - st < 1) + self.assertTrue(et - st < 1.5) self.assertEqual(len(status), 1) self.assertTrue(status[0].ok) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 14e365c6c0..a7c1eed157 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -321,8 +321,6 @@ def test_fully_async_mode(self): import multiprocessing - multiprocessing.set_start_method("spawn") - trainer_process = multiprocessing.Process(target=run_trainer, args=(trainer_config,)) trainer_process.start() From e02c5a39b0acf3faf36c8e8894c6aeae6cd8c096 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 25 Jun 2025 19:15:13 +0800 Subject: [PATCH 6/9] fix multiprocessing init method --- tests/trainer/trainer_test.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index a7c1eed157..250ea3eb40 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -1,4 +1,5 @@ """Tests for trainer.""" +import multiprocessing import os import shutil import time @@ -280,6 +281,10 @@ def run_explorer(config: Config) -> None: class TestFullyAsyncMode(unittest.TestCase): + def setUp(self): + if multiprocessing.get_start_method(allow_none=True) != "spawn": + multiprocessing.set_start_method("spawn", force=True) + def test_fully_async_mode(self): config = get_template_config() config.project = "unittest" @@ -319,8 +324,6 @@ def test_fully_async_mode(self): explorer2_config = deepcopy(explorer1_config) explorer1_config.check_and_update() - import multiprocessing - trainer_process = multiprocessing.Process(target=run_trainer, args=(trainer_config,)) trainer_process.start() From d924e9b00620bb71a2aaae8236c4992ab8975057 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 25 Jun 2025 20:08:12 +0800 Subject: [PATCH 7/9] add docs for async mode --- .../source/tutorial/example_async_mode.md | 80 ++++++++++++++++--- tests/explorer/runner_pool_test.py | 3 - 2 files changed, 68 insertions(+), 15 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/example_async_mode.md b/docs/sphinx_doc/source/tutorial/example_async_mode.md index 1f9a9c8665..70ca66e2b2 100644 --- a/docs/sphinx_doc/source/tutorial/example_async_mode.md +++ b/docs/sphinx_doc/source/tutorial/example_async_mode.md @@ -1,17 +1,17 @@ # Asynchronous RFT -This example shows how to run RFT 
in a fully asynchronous mode with the GRPO algorithm, Qwen2.5-1.5B-Instruct model and GSM8K dataset. +This example demonstrates how to run RFT in fully asynchronous mode using the GRPO algorithm, Qwen2.5-1.5B-Instruct model, and GSM8K dataset. -Trinity-RFT supports an asynchronous mode by running the trainer and explorer in separate processes. +Trinity-RFT supports Asynchronous RFT by running the trainer and explorer in separate processes. -For this purpose, we prepare two main config files: [`explorer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/explorer.yaml) and [`trainer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/trainer.yaml). -The main difference between them is that in `explorer.yaml` we set `mode` as `explore`, while in `trainer.yaml` we set `mode` as `train`. +For this purpose, we provide two main configuration files: [`explorer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/explorer.yaml) and [`trainer.yaml`](https://github.com/modelscope/Trinity-RFT/blob/main/examples/async_gsm8k/trainer.yaml). +The primary difference between them is that in `explorer.yaml` we set `mode` as `explore`, while in `trainer.yaml` we set `mode` as `train`. The model weights of the explorer and trainer are synchronized once every `sync_interval * batch_size` tasks. -Suppose we have a node of 8 GPUs; we use 4 GPUs for the trainer and 4 GPUs for the explorer. -Some important setups of `explorer.yaml` are listed in the following: +Assuming we have a node with 8 GPUs, we allocate 4 GPUs for the trainer and 4 GPUs for the explorer. Key configurations in `explorer.yaml` are as follows: ```yaml +# explorer.yaml project: name: mode: explore @@ -26,7 +26,7 @@ cluster: gpu_per_node: 4 buffer: total_epochs: 1 - batch_size: 96 + batch_size: 64 explorer_input: taskset: name: gsm8k @@ -45,7 +45,6 @@ buffer: storage_type: queue path: 'sqlite:///gsm8k.db' explorer: - eval_interval: 10 runner_num: 32 rollout_model: engine_type: vllm_async @@ -57,9 +56,10 @@ trainer: trainer_config_path: examples/async_gsm8k/verl_config.yaml ``` -Some important setups of `trainer.yaml` are listed in the following: +Key configurations in `trainer.yaml` are as follows: ```yaml +# trainer.yaml project: name: mode: train @@ -74,7 +74,7 @@ cluster: gpu_per_node: 4 buffer: total_epochs: 1 - batch_size: 96 + batch_size: 64 explorer_input: taskset: name: gsm8k @@ -98,8 +98,7 @@ trainer: trainer_config_path: examples/async_gsm8k/verl_config.yaml ``` - -You may run this example with the following command: +You can run this example with the following command: ```bash bash examples/async_gsm8k/run.sh @@ -110,3 +109,60 @@ The following plot shows the learning curve of GRPO in the asynchronous mode. > We are continuously investigating other RL algorithms (e.g., [OPMD](./example_reasoning_advanced.md)) in the asynchronous mode. ![async](../../assets/async-curve.png) + + +Trinity-RFT also supports dynamic scaling in asynchronous mode. Continuing with the previous example, if an additional machine with 8 GPUs joins the Ray cluster during training, you can launch a new explorer using the following configuration `explorer_new.yaml`. 
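+
+Note that the additional machine must already be part of the same Ray cluster (for example, by running `ray start --address=<head-node-ip>:<port>` on it) so that the new explorer can find the experience buffer shared with the trainer.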
+ +```yaml +# explorer_new.yaml +project: +name: +mode: explore +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: grpo + repeat_times: 8 +model: + model_path: /PATH/TO/MODEL/ +cluster: # important + node_num: 1 + gpu_per_node: 8 +explorer: + name: 'explorer_new' # important + runner_num: 64 + rollout_model: + engine_type: vllm_async + engine_num: 8 +buffer: + total_epochs: 1 + batch_size: 64 + explorer_input: + taskset: # important + name: gsm8k + storage_type: file + path: /PATH/TO/DATASET/ + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 1.0 + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: gsm8k_buffer + storage_type: queue + path: 'sqlite:///gsm8k.db' +synchronizer: + sync_method: 'checkpoint' + sync_interval: 10 +# other configs are the same as explorer.yaml +``` + +The differences between `explorer_new.yaml` and `explorer.yaml` include: + +- `cluster.node_num/gpu_per_node`: Specify the cluster configuration for the newly added explorer. +- `explorer.name`: The later-started explorer requires a different name than "explorer", which is the default name for the existing explorer. +- `explorer.rollout_model.engine_num/tensor_parallel_size`: Define the engine number and tensor parallel size to optimally utilize GPU resources. +- `buffer.explorer_input.taskset`: Provide another task dataset as input for the new explorer. + +All other parameters remain the same as in `explorer.yaml`. diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py index 8e9247344a..735255ecf2 100644 --- a/tests/explorer/runner_pool_test.py +++ b/tests/explorer/runner_pool_test.py @@ -250,9 +250,6 @@ def test_runner_pool_with_auxiliary_models(self): ) # `auxiliary_models` - st = time.time() status = pool.get_next_unorder() - et = time.time() - self.assertTrue(et - st < 1.5) self.assertEqual(len(status), 1) self.assertTrue(status[0].ok) From 815e6e086029f298cc997d87e316ae88d413b6af Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 25 Jun 2025 20:10:46 +0800 Subject: [PATCH 8/9] fix comments --- docs/sphinx_doc/source/tutorial/trinity_configs.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index c66af57c0d..88d925f786 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -246,7 +246,7 @@ The configuration for each task dataset is defined as follows: ### Explorer Output -In [`explore` mode](#global-configuration), since there is no trainer, users can configure an experience buffer via `buffer.explorer_input`, rather than using `buffer.trainer_input`, which will be introduced in the next section. +In [`explore` mode](#global-configuration), since there is no trainer, users can configure an experience buffer via `buffer.explorer_output`, rather than using `buffer.trainer_input`, which will be introduced in the next section. > For `both` and `train` modes, users should use `buffer.trainer_input` instead of `buffer.explorer_output`. @@ -265,6 +265,7 @@ buffer: - `queue`: Experience data is stored in a queue. This storage type is recommended for most use cases. - `sql`: Experience data is stored in a SQL database. If your database only supports local access (e.g., SQLite), set `wrap_in_ray` to `True` to wrap the database in a Ray actor, enabling remote access from other nodes. 
- `file`: Experience data is stored in a JSON file. This storage type should be used only for debugging purposes in `explore` mode. +- `path`: The path to the experience buffer. - For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data. - For `file` storage type, the path points to the directory containing the dataset files. - For `sql` storage type, the path points to the SQLite database file. From da34830630680ea9de395eb19d585566407a651b Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 25 Jun 2025 20:32:27 +0800 Subject: [PATCH 9/9] fix typo --- docs/sphinx_doc/source/tutorial/trinity_programming_guide.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index e07e6bb3dc..fb75d084b1 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -2,9 +2,8 @@ This guide introduces how to develop new modules in Trinity-RFT and provides relevant development guidelines. -Trinity-RFT consists of three main modules: **Explorer**, **Trainer** and **Buffer**. -We decouple the RL pipeline into three modules to make it easier to customize and extend. -Below is a table summarizing the modules and components that developers with different tragets need to focus on. +In Trinity-RFT, we decompose the RL pipeline into three main modules (**Explorer**, **Trainer** and **Buffer**) to facilitate customization and extension. +Below is a table summarizing the modules and components that developers with different targets need to focus on. | Development Target | Core Module | Key Component | |--------------------|-------------|---------------|