diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 87c2c95941..f34f1dffd4 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -275,6 +275,8 @@ buffer:
 - For `sql` storage type, the path points to the SQLite database file.
 - `wrap_in_ray`: Whether to wrap the experience buffer in a Ray actor. Only take effect when `storage_type` is `sql` or `file`. The `queue` storage always uses a Ray actor.
 - `max_read_timeout`: The maximum waiting time (in seconds) to read new experience data. If exceeded, an incomplete batch will be returned directly. Only take effect when `storage_type` is `queue`. Default is 1800 seconds (30 minutes).
+- `use_priority_queue`: Only takes effect when `storage_type` is `queue`. If set to `True`, the queue becomes a priority queue, which allows prioritizing certain experiences over others. Default is `False`.
+- `reuse_cooldown_time`: Only takes effect when `storage_type` is `queue` and `use_priority_queue` is `True`. If set, it specifies the cooldown time (in seconds) before a consumed experience can be reused. Default is `None`, meaning experiences cannot be reused.
 
 ### Trainer Input
 
diff --git a/pyproject.toml b/pyproject.toml
index fd7ca1f3c0..654e54827c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,7 @@ dependencies = [
     "tensorboard",
     "openai",
     "jsonlines",
+    "sortedcontainers",
 ]
 
 [project.scripts]
@@ -61,6 +62,7 @@ dev = [
     "mypy>=1.7.0",
    "pytest>=8.0.0",
     "pytest-json-ctrf",
+    "parameterized",
 ]
 
 doc = [
diff --git a/scripts/docker/Dockerfile b/scripts/docker/Dockerfile
index 5bf2b4a2e9..756ea61c42 100644
--- a/scripts/docker/Dockerfile
+++ b/scripts/docker/Dockerfile
@@ -20,7 +20,7 @@ RUN apt update && apt install -y \
 
 # For Aliyun users: update pip mirror to aliyun to speed up pip install
 RUN pip config set global.index-url http://mirrors.cloud.aliyuncs.com/pypi/simple/ \
-    && pip config set global.trusted-host mirrors.cloud.aliyuncs.com
+    && pip config set install.trusted-host mirrors.cloud.aliyuncs.com
 
 # copy the Trinity-RFT dir into the workspace
 COPY . .
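The two new options documented above map one-to-one onto the `StorageConfig` fields added in trinity/common/config.py later in this patch. A minimal sketch of a buffer config that exercises them (values are illustrative, not defaults):

    from trinity.common.config import StorageConfig
    from trinity.common.constants import StorageType

    # Priority queue with experience reuse after a 30-second cooldown.
    # "decay" controls how quickly repeated reads lower an item's priority.
    meta = StorageConfig(
        name="exp_buffer",
        storage_type=StorageType.QUEUE,
        use_priority_queue=True,   # default False: plain asyncio.Queue
        reuse_cooldown_time=30.0,  # default None: consumed items are never re-served
        replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.1},
    )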
diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py
index 81da20437f..702e8b8ca1 100644
--- a/tests/buffer/queue_test.py
+++ b/tests/buffer/queue_test.py
@@ -3,6 +3,7 @@
 import time
 
 import torch
+from parameterized import parameterized
 
 from tests.tools import RayUnittestBaseAysnc
 from trinity.buffer.reader.queue_reader import QueueReader
@@ -15,24 +16,29 @@
 
 
 class TestQueueBuffer(RayUnittestBaseAysnc):
-    async def test_queue_buffer(self):
-        total_num = 8
-        put_batch_size = 2
-        read_batch_size = 4
+    @parameterized.expand(
+        [
+            (
+                "queue",
+                False,
+            ),
+            (
+                "priority_queue",
+                True,
+            ),
+        ]
+    )
+    async def test_queue_buffer(self, name, use_priority_queue):
         meta = StorageConfig(
             name="test_buffer",
             algorithm_type="ppo",
             storage_type=StorageType.QUEUE,
             max_read_timeout=3,
             path=BUFFER_FILE_PATH,
+            use_priority_queue=use_priority_queue,
         )
-        config = BufferConfig(
-            max_retry_times=3,
-            max_retry_interval=1,
-            read_batch_size=read_batch_size,
-        )
-        writer = QueueWriter(meta, config)
-        reader = QueueReader(meta, config)
+        writer = QueueWriter(meta, self.config)
+        reader = QueueReader(meta, self.config)
         self.assertEqual(await writer.acquire(), 1)
         exps = [
             Experience(
@@ -41,37 +47,76 @@ async def test_queue_buffer(self):
                 reward=float(i),
                 logprobs=torch.tensor([0.1]),
             )
-            for i in range(1, put_batch_size + 1)
+            for i in range(1, self.put_batch_size + 1)
         ]
-        for _ in range(total_num // put_batch_size):
+        for exp in exps:
+            exp.info = {"model_version": 0, "use_count": 0}
+        for _ in range(self.total_num // self.put_batch_size):
             await writer.write_async(exps)
-        for _ in range(total_num // read_batch_size):
+        for _ in range(self.total_num // self.read_batch_size):
             exps = reader.read()
-            self.assertEqual(len(exps), read_batch_size)
-            print(f"finish read {read_batch_size} experience")
-        writer.write(
-            [
-                Experience(
-                    tokens=torch.tensor([float(j) for j in range(i + 1)]),
-                    prompt_length=i,
-                    reward=float(i),
-                    logprobs=torch.tensor([0.1]),
-                    action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
-                )
-                for i in range(1, put_batch_size * 2 + 1)
-            ]
-        )
-        exps = reader.read(batch_size=put_batch_size * 2)
-        self.assertEqual(len(exps), put_batch_size * 2)
+            self.assertEqual(len(exps), self.read_batch_size)
+            print(f"finish read {self.read_batch_size} experience")
+        exps = [
+            Experience(
+                tokens=torch.tensor([float(j) for j in range(i + 1)]),
+                prompt_length=i,
+                reward=float(i),
+                logprobs=torch.tensor([0.1]),
+                action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
+            )
+            for i in range(1, self.put_batch_size * 2 + 1)
+        ]
+        for exp in exps:
+            exp.info = {"model_version": 1, "use_count": 0}
+        writer.write(exps)
+        exps = reader.read(batch_size=self.put_batch_size * 2)
+        self.assertEqual(len(exps), self.put_batch_size * 2)
         self.assertEqual(await writer.release(), 0)
         self.assertRaises(StopIteration, reader.read)
         with open(BUFFER_FILE_PATH, "r") as f:
-            self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)
+            self.assertEqual(len(f.readlines()), self.total_num + self.put_batch_size * 2)
         st = time.time()
         self.assertRaises(TimeoutError, reader.read, batch_size=1)
         et = time.time()
         self.assertTrue(et - st > 2)
 
+    async def test_priority_queue_capacity(self):
+        # test queue capacity
+        meta = StorageConfig(
+            name="test_buffer_small",
+            algorithm_type="ppo",
+            storage_type=StorageType.QUEUE,
+            max_read_timeout=1,
+            capacity=2,
+            path=BUFFER_FILE_PATH,
+            use_priority_queue=True,
+            replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6},
+        )
+        writer = QueueWriter(meta, self.config)
+        reader = QueueReader(meta, self.config)
+
+        for i in range(4):
+            writer.write(
+                [
+                    Experience(
+                        tokens=torch.tensor([1, 2, 3]),
+                        prompt_length=2,
+                        info={"model_version": i, "use_count": 0},
+                    ),
+                ]
+            )
+
+        exps = reader.read(batch_size=2)
+        self.assertEqual(exps[0].info["model_version"], 3)
+        self.assertEqual(exps[0].info["use_count"], 1)
+        self.assertEqual(exps[1].info["model_version"], 2)
+        self.assertEqual(exps[1].info["use_count"], 1)
+
+        with self.assertRaises(TimeoutError):
+            reader.read(batch_size=1)
+
+    async def test_queue_buffer_capacity(self):
         # test queue capacity
         meta = StorageConfig(
@@ -81,8 +126,8 @@ async def test_queue_buffer(self):
             capacity=4,
             path=BUFFER_FILE_PATH,
         )
-        writer = QueueWriter(meta, config)
-        reader = QueueReader(meta, config)
+        writer = QueueWriter(meta, self.config)
+        reader = QueueReader(meta, self.config)
         writer.write([{"content": "hello"}])
         writer.write([{"content": "hi"}])
         writer.write([{"content": "hello"}])
@@ -100,6 +145,139 @@ def write_blocking_call():
         thread.join(timeout=1)
         self.assertFalse(thread.is_alive())
 
+    async def test_priority_queue_buffer_reuse(self):
+        # test queue reuse
+        meta = StorageConfig(
+            name="test_buffer_small",
+            algorithm_type="ppo",
+            storage_type=StorageType.QUEUE,
+            max_read_timeout=3,
+            capacity=4,
+            path=BUFFER_FILE_PATH,
+            use_priority_queue=True,
+            reuse_cooldown_time=0.5,
+            replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6},
+        )
+        writer = QueueWriter(meta, self.config)
+        reader = QueueReader(meta, self.config)
+        for i in range(4):
+            writer.write(
+                [
+                    Experience(
+                        tokens=torch.tensor([1, 2, 3]),
+                        prompt_length=2,
+                        info={"model_version": i, "use_count": 0},
+                    ),
+                    Experience(
+                        tokens=torch.tensor([1, 2, 3]),
+                        prompt_length=2,
+                        info={"model_version": i, "use_count": 0},
+                    ),
+                ]
+            )
+
+        # should not be blocked
+        def replace_call():
+            writer.write(
+                [
+                    Experience(
+                        tokens=torch.tensor([1, 2, 3]),
+                        prompt_length=2,
+                        info={"model_version": 4, "use_count": 0},
+                    ),
+                    Experience(
+                        tokens=torch.tensor([1, 2, 3]),
+                        prompt_length=2,
+                        info={"model_version": 4, "use_count": 0},
+                    ),
+                ]
+            )
+
+        thread = threading.Thread(target=replace_call)
+        thread.start()
+        thread.join(timeout=2)
+        self.assertFalse(thread.is_alive())
+
+        exps = reader.read(batch_size=4)
+        self.assertEqual(len(exps), 4)
+        self.assertEqual(exps[0].info["model_version"], 4)
+        self.assertEqual(exps[0].info["use_count"], 1)
+        self.assertEqual(exps[2].info["model_version"], 3)
+        self.assertEqual(exps[2].info["use_count"], 1)
+
+        # model_version 4,   3,   2,   1
+        # use_count     1,   1,   0,   0
+        # priority      3.4, 2.4, 2.0, 1.0
+
+        time.sleep(1)
+        exps = reader.read(batch_size=4)
+        self.assertEqual(len(exps), 4)
+        self.assertEqual(exps[0].info["model_version"], 4)
+        self.assertEqual(exps[0].info["use_count"], 2)
+        self.assertEqual(exps[2].info["model_version"], 3)
+        self.assertEqual(exps[2].info["use_count"], 2)
+
+        # model_version 4,   3,   2,   1
+        # use_count     2,   2,   0,   0
+        # priority      2.8, 1.8, 2.0, 1.0
+
+        time.sleep(1)
+        exps = reader.read(batch_size=4)
+        self.assertEqual(len(exps), 4)
+        self.assertEqual(exps[0].info["model_version"], 4)
+        self.assertEqual(exps[0].info["use_count"], 3)
+        self.assertEqual(exps[2].info["model_version"], 2)
+        self.assertEqual(exps[2].info["use_count"], 1)
+
+        # model_version 4,   3,   2,   1
+        # use_count     3,   2,   1,   0
+        # priority      2.2, 1.8, 1.4, 1.0
+
+        time.sleep(1)
+        exps = reader.read(batch_size=4)
+        self.assertEqual(len(exps), 4)
+        self.assertEqual(exps[0].info["model_version"], 4)
+        self.assertEqual(exps[0].info["use_count"], 4)
+        self.assertEqual(exps[2].info["model_version"], 3)
+        self.assertEqual(exps[2].info["use_count"], 3)
+
+        # model_version 4,   3,   2,   1
+        # use_count     4,   3,   1,   0
+        # priority      1.6, 1.2, 1.4, 1.0
+
+        time.sleep(1)
+        exps = reader.read(batch_size=4)
+        self.assertEqual(len(exps), 4)
+        self.assertEqual(exps[0].info["model_version"], 4)
+        self.assertEqual(exps[0].info["use_count"], 5)
+        self.assertEqual(exps[2].info["model_version"], 2)
+        self.assertEqual(exps[2].info["use_count"], 2)
+
+        # model_version 4,   3,   2,   1
+        # use_count     5,   3,   2,   0
+        # priority      1.0, 1.2, 0.8, 1.0
+
+        time.sleep(1)
+        exps = reader.read(batch_size=4)
+        self.assertEqual(len(exps), 4)
+        self.assertEqual(exps[0].info["model_version"], 3)
+        self.assertEqual(exps[0].info["use_count"], 4)
+        self.assertEqual(exps[2].info["model_version"], 1)
+        self.assertEqual(exps[2].info["use_count"], 1)
+
+        # model_version 4,   3,   2,   1
+        # use_count     5,   4,   2,   1
+        # priority      1.0, 0.6, 0.8, 0.4
+
     def setUp(self):
+        self.total_num = 8
+        self.put_batch_size = 2
+        self.read_batch_size = 4
+
+        self.config = BufferConfig(
+            max_retry_times=3,
+            max_retry_interval=1,
+            read_batch_size=self.read_batch_size,
+        )
         if os.path.exists(BUFFER_FILE_PATH):
             os.remove(BUFFER_FILE_PATH)
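The priority columns in the comments above follow from the linear_decay function registered in trinity/buffer/priority_queue.py: priority = model_version - decay * use_count, evaluated on the first experience of each queued item. A standalone restatement of the arithmetic for the first read round (decay=0.6 as in the tests; the helper below is illustrative, not imported from the library):

    DECAY = 0.6

    def priority(model_version: int, use_count: int) -> float:
        # Same formula as the registered "linear_decay" priority function.
        return model_version - DECAY * use_count

    # After the first read, versions 4 and 3 have been consumed once each:
    print(priority(4, 1))  # 3.4
    print(priority(3, 1))  # 2.4
    print(priority(2, 0))  # 2.0
    print(priority(1, 0))  # 1.0
    # 3.4 > 2.4 > 2.0 > 1.0, so the next read again serves versions 4 and 3,
    # matching the test's assertions on exps[0] and exps[2].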
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index cb125ac2b5..a697539141 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -10,6 +10,7 @@
 from datetime import datetime
 
 import ray
+from parameterized import parameterized
 
 from tests.tools import (
     RayUnittestBase,
@@ -301,7 +302,19 @@ def setUp(self):
         if multiprocessing.get_start_method(allow_none=True) != "spawn":
             multiprocessing.set_start_method("spawn", force=True)
 
-    def test_fully_async_mode(self):
+    @parameterized.expand(
+        [
+            (
+                "queue",
+                False,
+            ),
+            (
+                "priority_queue",
+                True,
+            ),
+        ]
+    )
+    def test_fully_async_mode(self, name, use_priority_queue):
         config = get_template_config()
         config.project = "unittest"
         config.name = f"fully_async_{datetime.now().strftime('%Y%m%d%H%M%S')}"
@@ -316,6 +329,7 @@ def test_fully_async_mode(self):
             name="exp_buffer",
             storage_type=StorageType.QUEUE,
             wrap_in_ray=True,
+            use_priority_queue=use_priority_queue,
         )
         config.synchronizer.sync_method = SyncMethod.CHECKPOINT
         config.synchronizer.sync_interval = 8
diff --git a/trinity/buffer/priority_queue.py b/trinity/buffer/priority_queue.py
new file mode 100644
index 0000000000..7878a3e28e
--- /dev/null
+++ b/trinity/buffer/priority_queue.py
@@ -0,0 +1,126 @@
+"""An Async PriorityQueue."""
+import asyncio
+from collections import deque
+from typing import List, Optional, Union
+
+import numpy as np
+from sortedcontainers import SortedDict
+
+from trinity.common.experience import Experience
+from trinity.utils.registry import Registry
+
+PRIORITY_FUNC = Registry("priority_fn")
+
+
+@PRIORITY_FUNC.register_module("linear_decay")
+def linear_decay_priority(item: List[Experience], decay: float = 0.1):
+    return item[0].info["model_version"] - decay * item[0].info["use_count"]  # type: ignore
+
+
+class AsyncPriorityQueue:
+    """
+    An asynchronous priority queue that manages a fixed-size buffer of experience items.
+    Items are prioritized using a user-defined function and reinserted after a cooldown period.
+
+    Attributes:
+        capacity (int): Maximum number of items the queue can hold.
+        priority_groups (SortedDict): Maps priorities to deques of items with the same priority.
+        priority_fn (callable): Function used to determine the priority of an item.
+        reuse_cooldown_time (Optional[float]): Delay before reusing an item (None disables reuse).
+    """
+
+    def __init__(
+        self,
+        capacity: int,
+        reuse_cooldown_time: Optional[float] = None,
+        priority_fn: str = "linear_decay",
+        **kwargs,
+    ):
+        """
+        Initialize the async priority queue.
+
+        Args:
+            capacity (`int`): The maximum number of items the queue can store.
+            reuse_cooldown_time (`float`): Time to wait before reusing an item. Set to None to disable reuse.
+            priority_fn (`str`): Name of the function to use for determining item priority.
+            kwargs: Additional keyword arguments for the priority function.
+        """
+        self.capacity = capacity
+        self.priority_groups = SortedDict()  # Maps priority -> deque of items
+        priority_fn = PRIORITY_FUNC.get(priority_fn)
+        from trinity.buffer.queue import QueueActor
+
+        # TODO: remove FINISH_MESSAGE and use a more elegant solution
+        self.FINISH_MESSAGE = QueueActor.FINISH_MESSAGE
+
+        self.priority_fn = (
+            lambda item: priority_fn(item, **kwargs) if item != self.FINISH_MESSAGE else -np.inf  # type: ignore
+        )
+        self.reuse_cooldown_time = reuse_cooldown_time
+        self._condition = asyncio.Condition()  # For async-safe operations
+
+    async def put(self, item: Union[List[Experience], str], delay: float = 0) -> None:
+        """
+        Insert an item into the queue, possibly replacing the lowest-priority item if full.
+
+        Args:
+            item (`Union[List[Experience], str]`): A list of experiences, or the finish message.
+            delay (`float`): Optional delay before insertion (used for cooldown-based reuse).
+        """
+        if delay > 0:
+            await asyncio.sleep(delay)
+
+        priority = self.priority_fn(item)
+        async with self._condition:
+            if len(self.priority_groups) == self.capacity:
+                # If full, only insert if the new item has higher or equal priority than the lowest
+                lowest_priority, item_queue = self.priority_groups.peekitem(index=0)
+                if lowest_priority > priority:
+                    return  # Skip insertion if lower priority
+                # Remove the lowest-priority item
+                item_queue.popleft()
+                if not item_queue:
+                    self.priority_groups.popitem(index=0)
+
+            # Add the new item
+            if priority not in self.priority_groups:
+                self.priority_groups[priority] = deque()
+            self.priority_groups[priority].append(item)
+            self._condition.notify()
+
+    async def get(self) -> List[Experience]:
+        """
+        Retrieve the highest-priority item from the queue.
+
+        Returns:
+            List[Experience]: The highest-priority item (list of experiences).
+
+        Notes:
+            - After retrieval, the item is optionally reinserted after a cooldown period.
+        """
+        async with self._condition:
+            while len(self.priority_groups) == 0:
+                await self._condition.wait()
+
+            _, item_queue = self.priority_groups.peekitem(index=-1)
+            item = item_queue.popleft()
+            if not item_queue:
+                self.priority_groups.popitem(index=-1)
+
+            if item != self.FINISH_MESSAGE:
+                for exp in item:
+                    exp.info["use_count"] += 1
+                # Optionally resubmit the item after a cooldown
+                if self.reuse_cooldown_time is not None:
+                    asyncio.create_task(self.put(item, self.reuse_cooldown_time))
+
+        return item
+
+    def size(self) -> int:
+        """
+        Get the current number of items in the queue.
+
+        Returns:
+            int: Number of priority groups (distinct priorities) currently stored.
+        """
+        return len(self.priority_groups)
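Taken together, put/get give a bounded buffer that always serves the freshest, least-used experiences first and re-queues consumed items once the cooldown expires. A self-contained sketch of that behavior (mirrors the unit tests; assumes the module imports cleanly outside a Ray actor):

    import asyncio

    import torch

    from trinity.buffer.priority_queue import AsyncPriorityQueue
    from trinity.common.experience import Experience

    async def demo():
        # decay=0.6 as in the tests; consumed items return after 0.5 s
        queue = AsyncPriorityQueue(
            capacity=4, reuse_cooldown_time=0.5, priority_fn="linear_decay", decay=0.6
        )
        for version in range(3):
            await queue.put(
                [
                    Experience(
                        tokens=torch.tensor([1, 2, 3]),
                        prompt_length=2,
                        info={"model_version": version, "use_count": 0},
                    )
                ]
            )
        item = await queue.get()  # highest priority wins: the newest model_version
        print(item[0].info)       # {'model_version': 2, 'use_count': 1}
        await asyncio.sleep(0.6)  # cooldown elapses; the item is re-queued
        print(queue.size())       # 3 priority groups again

    asyncio.run(demo())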
+ """ + return len(self.priority_groups) diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py index 534283c50b..c2d45029b3 100644 --- a/trinity/buffer/queue.py +++ b/trinity/buffer/queue.py @@ -5,6 +5,7 @@ import ray +from trinity.buffer.priority_queue import AsyncPriorityQueue from trinity.buffer.writer.file_writer import JSONWriter from trinity.buffer.writer.sql_writer import SQLWriter from trinity.common.config import BufferConfig, StorageConfig @@ -29,7 +30,14 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: self.logger = get_logger(__name__) self.config = config self.capacity = storage_config.capacity - self.queue = asyncio.Queue(self.capacity) + if storage_config.use_priority_queue: + reuse_cooldown_time = storage_config.reuse_cooldown_time + replay_buffer_kwargs = storage_config.replay_buffer_kwargs + self.queue = AsyncPriorityQueue( + self.capacity, reuse_cooldown_time, **replay_buffer_kwargs + ) + else: + self.queue = asyncio.Queue(self.capacity) st_config = deepcopy(storage_config) st_config.wrap_in_ray = False if st_config.path is not None: diff --git a/trinity/common/config.py b/trinity/common/config.py index 594375a1b2..87280cdbc8 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -88,6 +88,11 @@ class StorageConfig: # used for StorageType.QUEUE capacity: int = 10000 max_read_timeout: float = 1800 + use_priority_queue: bool = False + reuse_cooldown_time: Optional[float] = None + replay_buffer_kwargs: dict = field( + default_factory=lambda: {"priority_fn": "linear_decay", "decay": 0.1} + ) # used for rollout tasks default_workflow_type: Optional[str] = None diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index dee2de1508..e7eec9240f 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -84,6 +84,7 @@ def run_task(self, task: Task) -> Status: if not hasattr(exp, "info") or exp.info is None: exp.info = {} exp.info["model_version"] = self.model_wrapper.model_version + exp.info["use_count"] = 0 if not hasattr(exp, "metrics") or exp.metrics is None: exp.metrics = {} diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index b468382300..cabdd27987 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -191,6 +191,8 @@ def _expert_buffer_part(self): with st.expander("Experiences Buffer Configs", expanded=True): self.get_configs("storage_type") self.get_configs("experience_buffer_path") + self.get_configs("use_priority_queue") + self.get_configs("reuse_cooldown_time", "priority_fn", "priority_decay") self.buffer_advanced_tab = st.expander("Advanced Config") with self.buffer_advanced_tab: @@ -343,7 +345,6 @@ def _generate_verl_config(self): trainer_config = { "actor_rollout_ref": { - "hybrid_engine": True, "model": { "external_lib": None, "override_config": {}, @@ -352,7 +353,6 @@ def _generate_verl_config(self): }, "actor": { "strategy": st.session_state["training_strategy"], - "ppo_mini_batch_size": st.session_state["train_batch_size"], "ppo_micro_batch_size_per_gpu": st.session_state[ "actor_ppo_micro_batch_size_per_gpu" ], @@ -498,6 +498,14 @@ def _gen_buffer_config(self): "max_retry_times": st.session_state["buffer_max_retry_times"], "max_retry_interval": st.session_state["max_retry_interval"], } + if st.session_state["algorithm_type"] != "dpo": + experience_buffer = buffer_config["trainer_input"]["experience_buffer"] + experience_buffer["use_priority_queue"] = 
+            experience_buffer["reuse_cooldown_time"] = st.session_state["reuse_cooldown_time"]
+            experience_buffer["replay_buffer_kwargs"] = {
+                "priority_fn": st.session_state["priority_fn"],
+                "decay": st.session_state["priority_decay"],
+            }
 
         if st.session_state["mode"] != "train":
             buffer_config["explorer_input"] = {
diff --git a/trinity/manager/config_registry/buffer_config_manager.py b/trinity/manager/config_registry/buffer_config_manager.py
index f704d0ecd2..f7a1893f50 100644
--- a/trinity/manager/config_registry/buffer_config_manager.py
+++ b/trinity/manager/config_registry/buffer_config_manager.py
@@ -1,5 +1,6 @@
 import streamlit as st
 
+from trinity.buffer.priority_queue import PRIORITY_FUNC
 from trinity.common.constants import PromptType, StorageType
 from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS
 from trinity.common.workflows.workflow import WORKFLOWS
@@ -269,7 +270,7 @@ def set_storage_type(**kwargs):
         storage_candidates = [StorageType.FILE.value, StorageType.SQL.value]
     else:
         st.session_state[key] = st.session_state["_not_dpo_storage_type"]
-        storage_candidates = [StorageType.QUEUE.value, StorageType.SQL.value]
+        storage_candidates = [StorageType.QUEUE.value]
 
     def on_change():
         if st.session_state["algorithm_type"] == "dpo":
@@ -285,6 +286,47 @@ def on_change():
     )
 
 
+@CONFIG_GENERATORS.register_config(default_value=False)
+def set_use_priority_queue(**kwargs):
+    st.checkbox("Use Priority Queue", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(
+    default_value=None, visible=lambda: st.session_state["use_priority_queue"]
+)
+def set_reuse_cooldown_time(**kwargs):
+    st.number_input(
+        "Reuse Cooldown Time",
+        min_value=0.0,
+        max_value=1e5,
+        help="Leave blank to indicate no reuse",
+        placeholder=None,
+        **kwargs,
+    )
+
+
+@CONFIG_GENERATORS.register_config(
+    default_value="linear_decay", visible=lambda: st.session_state["use_priority_queue"]
+)
+def set_priority_fn(**kwargs):
+    candidates = list(PRIORITY_FUNC.modules.keys())
+    st.selectbox(
+        "Priority Function",
+        candidates,
+        **kwargs,
+    )
+
+
+@CONFIG_GENERATORS.register_config(
+    default_value=0.1, visible=lambda: st.session_state["use_priority_queue"]
+)
+def set_priority_decay(**kwargs):
+    st.number_input(
+        "Priority Decay",
+        **kwargs,
+    )
+
+
 @CONFIG_GENERATORS.register_config(
     default_value="",
     other_configs={
@@ -307,11 +349,7 @@ def set_experience_buffer_path(**kwargs):  # TODO
     else:
         st.session_state[key] = st.session_state["_not_dpo_experience_buffer_path"]
         title = "Experience Buffer Path"
-        help_msg = r"""This path is used for `trainer`,
-
-if `storage_type == StorageType.QUEUE`, default to `None`,
-
-if `storage_type == StorageType.SQL`, default to `sqlite:///{os.path.join(checkpoint_root_dir, '.cache', project_name, experiment_name)}/data.db`."""
+        help_msg = r"""This path is used for persistent storage of experiences; defaults to `None`."""
 
     def on_change():
         if st.session_state["algorithm_type"] == "dpo":
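Because set_priority_fn enumerates PRIORITY_FUNC.modules.keys(), any function registered against the same registry shows up in the selector without further UI changes. A hypothetical sketch of a custom priority function (the "reward_weighted" name and its weighting scheme are invented for illustration):

    from typing import List

    from trinity.buffer.priority_queue import PRIORITY_FUNC
    from trinity.common.experience import Experience

    @PRIORITY_FUNC.register_module("reward_weighted")  # hypothetical registry key
    def reward_weighted_priority(item: List[Experience], weight: float = 1.0):
        # Prefer fresh rollouts; break ties with the observed reward.
        exp = item[0]
        reward = exp.reward if exp.reward is not None else 0.0
        return exp.info["model_version"] + weight * reward

    # Selected through the same config surface as linear_decay, e.g.:
    #   replay_buffer_kwargs={"priority_fn": "reward_weighted", "weight": 0.5}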