From 29bcb9f442774ca11b0cda53dc65cc5282a98338 Mon Sep 17 00:00:00 2001
From: pxc
Date: Thu, 23 Oct 2025 14:52:40 +0800
Subject: [PATCH 01/21] simplify reader writer get function

---
 tests/buffer/file_test.py                    |   9 +-
 tests/buffer/sql_test.py                     |  10 +-
 tests/cli/launcher_test.py                   |  10 ++
 .../sample_strategy/mix_sample_strategy.py   |   1 -
 trinity/buffer/buffer.py                     |  22 ++--
 .../buffer/pipelines/experience_pipeline.py  |  16 +--
 trinity/buffer/reader/file_reader.py         |  48 ++++----
 trinity/buffer/reader/queue_reader.py        |  10 +-
 trinity/buffer/reader/sql_reader.py          |  10 +-
 trinity/buffer/schema/formatter.py           |   3 +-
 trinity/buffer/storage/file.py               |  30 +++--
 trinity/buffer/storage/queue.py              |  38 +++----
 trinity/buffer/storage/sql.py                |  94 ++++++++--------
 trinity/buffer/utils.py                      |  25 ++---
 trinity/buffer/writer/file_writer.py         |  10 +-
 trinity/buffer/writer/queue_writer.py        |   9 +-
 trinity/buffer/writer/sql_writer.py          |  10 +-
 trinity/common/config.py                     | 103 +++++++++++++++++-
 trinity/common/constants.py                  |   1 +
 trinity/explorer/explorer.py                 |   2 +-
 trinity/explorer/workflow_runner.py          |   2 +-
 21 files changed, 278 insertions(+), 185 deletions(-)

diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py
index 00aec40744..6fc2868b5c 100644
--- a/tests/buffer/file_test.py
+++ b/tests/buffer/file_test.py
@@ -110,7 +110,8 @@ async def test_file_writer(self):
         file_wrapper = ray.get_actor("json-test_buffer")
         self.assertIsNotNone(file_wrapper)
         file_path = default_storage_path(
-            self.config.buffer.trainer_input.experience_buffer, self.config.buffer
+            self.config.buffer.trainer_input.experience_buffer.name,
+            self.config.buffer.trainer_input.experience_buffer.storage_type,
         )
         with open(file_path, "r") as f:
             self.assertEqual(len(f.readlines()), 4)
@@ -130,11 +131,13 @@ def setUp(self):
         os.makedirs(self.config.buffer.cache_dir, exist_ok=True)
         if os.path.exists(
             default_storage_path(
-                self.config.buffer.trainer_input.experience_buffer, self.config.buffer
+                self.config.buffer.trainer_input.experience_buffer.name,
+                self.config.buffer.trainer_input.experience_buffer.storage_type,
             )
         ):
             os.remove(
                 default_storage_path(
-                    self.config.buffer.trainer_input.experience_buffer, self.config.buffer
+                    self.config.buffer.trainer_input.experience_buffer.name,
+                    self.config.buffer.trainer_input.experience_buffer.storage_type,
                 )
             )
diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py
index 1934b8fa6c..afe0e02c1c 100644
--- a/tests/buffer/sql_test.py
+++ b/tests/buffer/sql_test.py
@@ -6,7 +6,7 @@
 from tests.tools import RayUnittestBaseAysnc
 from trinity.buffer.reader.sql_reader import SQLReader
 from trinity.buffer.writer.sql_writer import SQLWriter
-from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.config import StorageConfig
 from trinity.common.constants import StorageType
 from trinity.common.experience import Experience
@@ -23,12 +23,10 @@ async def test_sql_buffer_read_write(self) -> None:
             schema_type="experience",
             path=f"sqlite:///{db_path}",
             storage_type=StorageType.SQL,
+            batch_size=read_batch_size,
         )
-        config = BufferConfig(
-            train_batch_size=read_batch_size,
-        )
-        sql_writer = SQLWriter(meta, config)
-        sql_reader = SQLReader(meta, config)
+        sql_writer = SQLWriter(meta)
+        sql_reader = SQLReader(meta)
         exps = [
             Experience(
                 tokens=torch.tensor([float(j) for j in range(i + 1)]),
diff --git a/tests/cli/launcher_test.py b/tests/cli/launcher_test.py
index 1b8ab142e8..294bcf50fc 100644
--- a/tests/cli/launcher_test.py
+++ b/tests/cli/launcher_test.py
@@ -21,6 +21,7 @@
     TrainerInput,
 )
 from trinity.common.constants import (
+    CHECKPOINT_JOB_DIR_ENV_VAR,
     LOG_DIR_ENV_VAR,
     LOG_LEVEL_ENV_VAR,
     LOG_NODE_IP_ENV_VAR,
@@ -118,6 +119,9 @@ def test_main_run_in_dlc(self, mock_init, mock_load, mock_both, mock_setup, mock
         runtime_env={
             "env_vars": {
                 launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
+                CHECKPOINT_JOB_DIR_ENV_VAR: os.path.join(
+                    config.checkpoint_root_dir, config.project, config.name
+                ),
                 LOG_DIR_ENV_VAR: config.log.save_dir,
                 LOG_LEVEL_ENV_VAR: config.log.level,
                 LOG_NODE_IP_ENV_VAR: "1",
@@ -212,6 +216,9 @@ def test_multi_stage_run(
         runtime_env={
             "env_vars": {
                 launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
+                CHECKPOINT_JOB_DIR_ENV_VAR: os.path.join(
+                    config.checkpoint_root_dir, config.project, config.name
+                ),
                 LOG_DIR_ENV_VAR: os.path.join(
                     config.checkpoint_root_dir,
                     config.project,
@@ -230,6 +237,9 @@ def test_multi_stage_run(
         runtime_env={
             "env_vars": {
                 launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins",
+                CHECKPOINT_JOB_DIR_ENV_VAR: os.path.join(
+                    config.checkpoint_root_dir, config.project, config.name
+                ),
                 LOG_DIR_ENV_VAR: os.path.join(
                     config.checkpoint_root_dir,
                     config.project,
diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py
index 5e535a6d25..7bfa97d7a4 100644
--- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py
+++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py
@@ -55,7 +55,6 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
         expert_buffer_config.train_batch_size = expert_batch_size
         self.expert_exp_buffer = get_buffer_reader(
             buffer_config.trainer_input.auxiliary_buffers[self.sft_dataset_name],
-            expert_buffer_config,
         )

     async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py
index a7d52e60a7..143130ab7c 100644
--- a/trinity/buffer/buffer.py
+++ b/trinity/buffer/buffer.py
@@ -2,20 +2,20 @@
 """The buffer module"""
 from trinity.buffer.buffer_reader import BufferReader
 from trinity.buffer.buffer_writer import BufferWriter
-from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.config import StorageConfig
 from trinity.common.constants import StorageType


-def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig) -> BufferReader:
+def get_buffer_reader(storage_config: StorageConfig) -> BufferReader:
     """Get a buffer reader for the given dataset name."""
     if storage_config.storage_type == StorageType.SQL:
         from trinity.buffer.reader.sql_reader import SQLReader

-        return SQLReader(storage_config, buffer_config)
+        return SQLReader(storage_config)
     elif storage_config.storage_type == StorageType.QUEUE:
         from trinity.buffer.reader.queue_reader import QueueReader

-        return QueueReader(storage_config, buffer_config)
+        return QueueReader(storage_config)
     elif storage_config.storage_type == StorageType.FILE:
         from trinity.buffer.reader.file_reader import (
             ExperienceFileReader,
@@ -25,26 +25,26 @@ def get_buffer_reader(storage_config: StorageConfig, buffer_config: BufferConfig
         schema_type = storage_config.schema_type
         if schema_type:
             # only trainer input has schema type
-            return ExperienceFileReader(storage_config, buffer_config)
+            return ExperienceFileReader(storage_config)
         else:
-            return TaskFileReader(storage_config, buffer_config)
+            return TaskFileReader(storage_config)
     else:
         raise ValueError(f"{storage_config.storage_type} not supported.")


-def get_buffer_writer(storage_config: StorageConfig, buffer_config: BufferConfig) -> BufferWriter:
+def get_buffer_writer(storage_config: StorageConfig) -> BufferWriter:
     """Get a buffer writer for the given dataset name."""
     if storage_config.storage_type == StorageType.SQL:
         from trinity.buffer.writer.sql_writer import SQLWriter

-        return SQLWriter(storage_config, buffer_config)
+        return SQLWriter(storage_config)
     elif storage_config.storage_type == StorageType.QUEUE:
         from trinity.buffer.writer.queue_writer import QueueWriter

-        return QueueWriter(storage_config, buffer_config)
+        return QueueWriter(storage_config)
     elif storage_config.storage_type == StorageType.FILE:
         from trinity.buffer.writer.file_writer import JSONWriter

-        return JSONWriter(storage_config, buffer_config)
+        return JSONWriter(storage_config)
     else:
-        raise ValueError(f"{storage_config.storage_type} not supported.")
+        raise ValueError(f"{storage_config} not supported.")
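After this change, a reader or writer is constructed from the `StorageConfig` alone; the separate `BufferConfig` argument is gone, and per-run values such as the batch size travel inside the storage config itself. A minimal sketch of the new call pattern (the field values are illustrative, not taken from the patch):

```python
from trinity.buffer.buffer import get_buffer_reader, get_buffer_writer
from trinity.common.config import StorageConfig
from trinity.common.constants import StorageType

# One config object now carries everything the factory needs, including
# batch_size, which previously came from BufferConfig.train_batch_size.
storage = StorageConfig(
    name="demo_buffer",
    storage_type=StorageType.QUEUE,
    schema_type="experience",
    batch_size=4,  # illustrative value
)

writer = get_buffer_writer(storage)  # dispatches to QueueWriter(storage)
reader = get_buffer_reader(storage)  # dispatches to QueueReader(storage)
```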
diff --git a/trinity/buffer/pipelines/experience_pipeline.py b/trinity/buffer/pipelines/experience_pipeline.py
index f92b4c638e..8c5c53d070 100644
--- a/trinity/buffer/pipelines/experience_pipeline.py
+++ b/trinity/buffer/pipelines/experience_pipeline.py
@@ -6,7 +6,6 @@
 from trinity.buffer.storage.queue import is_database_url, is_json_file
 from trinity.common.config import (
     AlgorithmConfig,
-    BufferConfig,
     Config,
     ExperiencePipelineConfig,
     StorageConfig,
@@ -17,13 +16,11 @@
 from trinity.utils.plugin_loader import load_plugins


-def get_input_buffers(
-    pipeline_config: ExperiencePipelineConfig, buffer_config: BufferConfig
-) -> Dict:
+def get_input_buffers(pipeline_config: ExperiencePipelineConfig) -> Dict:
     """Get input buffers for the experience pipeline."""
     input_buffers = {}
     for input_name, input_config in pipeline_config.inputs.items():
-        buffer_reader = get_buffer_reader(input_config, buffer_config)
+        buffer_reader = get_buffer_reader(input_config)
         input_buffers[input_name] = buffer_reader
     return input_buffers
@@ -37,8 +34,7 @@ def __init__(self, config: Config):
         self.logger = get_logger(f"{config.explorer.name}_experience_pipeline", in_ray_actor=True)
         load_plugins()
         pipeline_config = config.data_processor.experience_pipeline
-        buffer_config = config.buffer
-        self.input_store = self._init_input_storage(pipeline_config, buffer_config)  # type: ignore [arg-type]
+        self.input_store = self._init_input_storage(pipeline_config)  # type: ignore [arg-type]
         try:
             self.operators = ExperienceOperator.create_operators(pipeline_config.operators)
         except Exception as e:
@@ -46,14 +42,12 @@ def __init__(self, config: Config):
             raise e
         self._set_algorithm_operators(config.algorithm)
         self.output = get_buffer_writer(
-            buffer_config.trainer_input.experience_buffer,  # type: ignore [arg-type]
-            buffer_config,
+            config.buffer.trainer_input.experience_buffer,  # type: ignore [arg-type]
         )

     def _init_input_storage(
         self,
         pipeline_config: ExperiencePipelineConfig,
-        buffer_config: BufferConfig,
     ) -> Optional[BufferWriter]:
         """Initialize the input storage if it is not already set."""
         if pipeline_config.save_input:
@@ -66,7 +60,6 @@ def _init_input_storage(
                     path=pipeline_config.input_save_path,
                     wrap_in_ray=False,
                 ),
-                buffer_config,
             )
         elif is_database_url(pipeline_config.input_save_path):
             return get_buffer_writer(
@@ -75,7 +68,6 @@ def _init_input_storage(
                     path=pipeline_config.input_save_path,
                     wrap_in_ray=False,
                 ),
-                buffer_config,
             )
         else:
             raise ValueError(
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index b79e87285d..93d2b5d54f 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -7,7 +7,7 @@
 from trinity.buffer.buffer_reader import BufferReader
 from trinity.buffer.schema.formatter import FORMATTER
-from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.config import StorageConfig


 class DummyProgressBar:
@@ -101,19 +101,19 @@ async def read_async(self, batch_size: Optional[int] = None):
 class ExperienceFileReader(BaseFileReader):
     """Reader for SFT / DPO file data."""

-    def __init__(self, meta: StorageConfig, config: BufferConfig):
-        self.formatter = FORMATTER.get(meta.schema_type)(
-            tokenizer_path=config.tokenizer_path, format_config=meta.format
+    def __init__(self, config: StorageConfig):
+        self.formatter = FORMATTER.get(config.schema_type)(
+            tokenizer_path=config.tokenizer_path, format_config=config.format
         )
         self.read_batch_size = config.train_batch_size
         self.dataset = _HFBatchReader(
-            load_dataset(meta.path, name=meta.subset_name, split=meta.split),
-            name=meta.name,
+            load_dataset(config.path, name=config.subset_name, split=config.split),
+            name=config.name,
             default_batch_size=self.read_batch_size,
-            total_epochs=meta.total_epochs,
+            total_epochs=config.total_epochs,
             drop_last=True,
-            total_steps=meta.total_steps,
-            enable_progress_bar=meta.enable_progress_bar,
+            total_steps=config.total_steps,
+            enable_progress_bar=config.enable_progress_bar,
         )

     def read(self, batch_size: Optional[int] = None) -> List:
@@ -126,26 +126,25 @@ def read(self, batch_size: Optional[int] = None) -> List:


 class TaskFileReader(BaseFileReader):
-    def __init__(self, meta: StorageConfig, config: BufferConfig):
-        self.meta = meta
-        self.name = meta.name
-        self.split = meta.split
-        subset_name = meta.subset_name
-        # disable datasets caching to avoid reuse old-version dataset
+    """A Reader for task file data."""
+
+    def __init__(self, config: StorageConfig):
+        self.config = config
+        self.name = config.name
         self.epoch = 0
         datasets.disable_caching()
         self.read_batch_size = config.batch_size
         self.dataset = _HFBatchReader(
-            load_dataset(meta.path, name=subset_name, split=self.split),
-            name=meta.name,
+            load_dataset(self.config.path, name=self.config.subset_name, split=self.config.split),
+            name=self.config.name,
             default_batch_size=self.read_batch_size,
-            total_epochs=self.meta.total_epochs if not self.meta.is_eval else 1,
-            offset=self.meta.index,
-            drop_last=not self.meta.is_eval,
-            total_steps=meta.total_steps,
-            enable_progress_bar=meta.enable_progress_bar,
+            total_epochs=self.config.total_epochs if not self.config.is_eval else 1,
+            offset=self.config.index,
+            drop_last=not self.config.is_eval,
+            total_steps=self.config.total_steps,
+            enable_progress_bar=self.config.enable_progress_bar,
         )
-        self.formatter = FORMATTER.get("task")(meta)
+        self.formatter = FORMATTER.get("task")(config)

     def read(self, batch_size: Optional[int] = None) -> List:
         batch_size = batch_size or self.read_batch_size
@@ -155,3 +154,6 @@ def read(self, batch_size: Optional[int] = None) -> List:
             task = self.formatter.format(sample)
             tasks.append(task)
         return tasks
+
+    def __len__(self) -> int:
+        return len(self.dataset.dataset)
diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py
index e4ea695a21..46036bfe86 100644
--- a/trinity/buffer/reader/queue_reader.py
+++ b/trinity/buffer/reader/queue_reader.py
@@ -6,18 +6,18 @@
 from trinity.buffer.buffer_reader import BufferReader
 from trinity.buffer.storage.queue import QueueStorage
-from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.config import StorageConfig
 from trinity.common.constants import StorageType


 class QueueReader(BufferReader):
     """Reader of the Queue buffer."""

-    def __init__(self, storage_config: StorageConfig, config: BufferConfig):
-        assert storage_config.storage_type == StorageType.QUEUE
-        self.timeout = storage_config.max_read_timeout
+    def __init__(self, config: StorageConfig):
+        assert config.storage_type == StorageType.QUEUE
+        self.timeout = config.max_read_timeout
         self.read_batch_size = config.train_batch_size
-        self.queue = QueueStorage.get_wrapper(storage_config, config)
+        self.queue = QueueStorage.get_wrapper(config)

     def read(self, batch_size: Optional[int] = None) -> List:
         try:
diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py
index 7e45b842d2..d44d2f244f 100644
--- a/trinity/buffer/reader/sql_reader.py
+++ b/trinity/buffer/reader/sql_reader.py
@@ -6,17 +6,17 @@
 from trinity.buffer.buffer_reader import BufferReader
 from trinity.buffer.storage.sql import SQLStorage
-from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.config import StorageConfig
 from trinity.common.constants import StorageType


 class SQLReader(BufferReader):
     """Reader of the SQL buffer."""

-    def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
-        assert meta.storage_type == StorageType.SQL
-        self.wrap_in_ray = meta.wrap_in_ray
-        self.storage = SQLStorage.get_wrapper(meta, config)
+    def __init__(self, config: StorageConfig) -> None:
+        assert config.storage_type == StorageType.SQL
+        self.wrap_in_ray = config.wrap_in_ray
+        self.storage = SQLStorage.get_wrapper(config)

     def read(self, batch_size: Optional[int] = None) -> List:
         if self.wrap_in_ray:
diff --git a/trinity/buffer/schema/formatter.py b/trinity/buffer/schema/formatter.py
index 4284a321ad..976d23eb76 100644
--- a/trinity/buffer/schema/formatter.py
+++ b/trinity/buffer/schema/formatter.py
@@ -38,7 +38,6 @@ class TaskFormatter:

     def __init__(self, config: StorageConfig):
         self.config = config
-        self.is_eval = config.is_eval
         self.default_workflow_cls = WORKFLOWS.get(config.default_workflow_type)  # type: ignore
         self.default_reward_fn_cls = REWARD_FUNCTIONS.get(config.default_reward_fn_type)  # type: ignore
         self.workflow_key = config.format.workflow_key
@@ -65,7 +64,7 @@ def format(self, sample: Dict) -> Task:
             rollout_args=self.config.rollout_args,
             workflow_args=self.config.workflow_args,
             reward_fn_args=self.config.reward_fn_args,
-            is_eval=self.is_eval,
+            is_eval=self.config.is_eval,
             raw_task=sample,
         )
diff --git a/trinity/buffer/storage/file.py b/trinity/buffer/storage/file.py
index 9af8fc5520..3de80bcde7 100644
--- a/trinity/buffer/storage/file.py
+++ b/trinity/buffer/storage/file.py
@@ -6,7 +6,7 @@
 import ray

 from trinity.buffer.utils import default_storage_path
-from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.config import StorageConfig
 from trinity.common.experience import EID, Experience
 from trinity.common.workflows import Task
@@ -33,34 +33,32 @@ class FileStorage:
     StorageType.QUEUE instead.
     """

-    def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
-        if not storage_config.path:
-            storage_config.path = default_storage_path(storage_config, config)
-        ext = os.path.splitext(storage_config.path)[-1]
+    def __init__(self, config: StorageConfig) -> None:
+        if not config.path:
+            config.path = default_storage_path(config.name, config.storage_type)
+        ext = os.path.splitext(config.path)[-1]
         if ext != ".jsonl" and ext != ".json":
-            raise ValueError(
-                f"File path must end with '.json' or '.jsonl', got {storage_config.path}"
-            )
-        path_dir = os.path.dirname(os.path.abspath(storage_config.path))
+            raise ValueError(f"File path must end with '.json' or '.jsonl', got {config.path}")
+        path_dir = os.path.dirname(os.path.abspath(config.path))
         os.makedirs(path_dir, exist_ok=True)
-        self.file = open(storage_config.path, "a", encoding="utf-8")
+        self.file = open(config.path, "a", encoding="utf-8")
         self.encoder = _Encoder(ensure_ascii=False)
         self.ref_count = 0

     @classmethod
-    def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
-        if storage_config.wrap_in_ray:
+    def get_wrapper(cls, config: StorageConfig):
+        if config.wrap_in_ray:
             return (
                 ray.remote(cls)
                 .options(
-                    name=f"json-{storage_config.name}",
-                    namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
+                    name=f"json-{config.name}",
+                    namespace=config.ray_namespace or ray.get_runtime_context().namespace,
                     get_if_exists=True,
                 )
-                .remote(storage_config, config)
+                .remote(config)
             )
         else:
-            return cls(storage_config, config)
+            return cls(config)

     def write(self, data: List) -> None:
         for item in data:
diff --git a/trinity/buffer/storage/queue.py b/trinity/buffer/storage/queue.py
index 06257c3cab..324122746c 100644
--- a/trinity/buffer/storage/queue.py
+++ b/trinity/buffer/storage/queue.py
@@ -11,7 +11,7 @@
 import ray
 from sortedcontainers import SortedDict

-from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.config import StorageConfig
 from trinity.common.constants import StorageType
 from trinity.common.experience import Experience
 from trinity.utils.log import get_logger
@@ -93,19 +93,19 @@ def stopped(self) -> bool:
         """Check if there is no more data to read."""

     @classmethod
-    def get_queue(cls, storage_config: StorageConfig, config: BufferConfig) -> "QueueBuffer":
+    def get_queue(cls, config: StorageConfig) -> "QueueBuffer":
         """Get a queue instance based on the storage configuration."""
         logger = get_logger(__name__)
-        if storage_config.use_priority_queue:
-            reuse_cooldown_time = storage_config.reuse_cooldown_time
-            replay_buffer_kwargs = storage_config.replay_buffer_kwargs
-            capacity = storage_config.capacity
+        if config.use_priority_queue:
+            reuse_cooldown_time = config.reuse_cooldown_time
+            replay_buffer_kwargs = config.replay_buffer_kwargs
+            capacity = config.capacity
             logger.info(
                 f"Using AsyncPriorityQueue with capacity {capacity}, reuse_cooldown_time {reuse_cooldown_time}."
             )
             return AsyncPriorityQueue(capacity, reuse_cooldown_time, **replay_buffer_kwargs)
         else:
-            return AsyncQueue(capacity=storage_config.capacity)
+            return AsyncQueue(capacity=config.capacity)


 class AsyncQueue(asyncio.Queue, QueueBuffer):
@@ -258,24 +258,24 @@ def stopped(self) -> bool:
 class QueueStorage:
     """An wrapper of a async queue."""

-    def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
-        self.logger = get_logger(f"queue_{storage_config.name}", in_ray_actor=True)
+    def __init__(self, config: StorageConfig) -> None:
+        self.logger = get_logger(f"queue_{config.name}", in_ray_actor=True)
         self.config = config
-        self.capacity = storage_config.capacity
-        self.queue = QueueBuffer.get_queue(storage_config, config)
-        st_config = deepcopy(storage_config)
+        self.capacity = config.capacity
+        self.queue = QueueBuffer.get_queue(config)
+        st_config = deepcopy(config)
         st_config.wrap_in_ray = False
         if st_config.path:
             if is_database_url(st_config.path):
                 from trinity.buffer.writer.sql_writer import SQLWriter

                 st_config.storage_type = StorageType.SQL
-                self.writer = SQLWriter(st_config, self.config)
+                self.writer = SQLWriter(st_config)
             elif is_json_file(st_config.path):
                 from trinity.buffer.writer.file_writer import JSONWriter

                 st_config.storage_type = StorageType.FILE
-                self.writer = JSONWriter(st_config, self.config)
+                self.writer = JSONWriter(st_config)
             else:
                 self.logger.warning("Unknown supported storage path: %s", st_config.path)
                 self.writer = None
@@ -283,7 +283,7 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None
             from trinity.buffer.writer.file_writer import JSONWriter

             st_config.storage_type = StorageType.FILE
-            self.writer = JSONWriter(st_config, self.config)
+            self.writer = JSONWriter(st_config)
             self.logger.warning(f"Save experiences in {st_config.path}.")
         self.ref_count = 0
         self.exp_pool = deque()  # A pool to store experiences
@@ -335,14 +335,14 @@ async def get_batch(self, batch_size: int, timeout: float) -> List:
         return [self.exp_pool.popleft() for _ in range(batch_size)]

     @classmethod
-    def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
+    def get_wrapper(cls, config: StorageConfig):
         """Get the queue actor."""
         return (
             ray.remote(cls)
             .options(
-                name=f"queue-{storage_config.name}",
-                namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
+                name=f"queue-{config.name}",
+                namespace=config.ray_namespace or ray.get_runtime_context().namespace,
                 get_if_exists=True,
             )
-            .remote(storage_config, config)
+            .remote(config)
         )
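With the same simplification applied here, `QueueBuffer.get_queue` now reads every queue option straight off the storage config. A hedged sketch of how the two queue variants are selected (the constructor values are illustrative and mirror the kwargs used elsewhere in this series):

```python
from trinity.buffer.storage.queue import QueueBuffer
from trinity.common.config import StorageConfig
from trinity.common.constants import StorageType

# Priority queue: capacity, cooldown, and replay kwargs all come from the config.
cfg = StorageConfig(
    name="replay",
    storage_type=StorageType.QUEUE,
    use_priority_queue=True,
    capacity=4,
    replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6},
)
queue = QueueBuffer.get_queue(cfg)  # AsyncPriorityQueue(4, None, priority_fn=..., decay=0.6)

# Plain FIFO queue when use_priority_queue is left False.
cfg.use_priority_queue = False
fifo = QueueBuffer.get_queue(cfg)  # AsyncQueue(capacity=4)
```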
""" - def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: - self.logger = get_logger(f"sql_{storage_config.name}", in_ray_actor=True) - if not storage_config.path: - storage_config.path = default_storage_path(storage_config, config) + def __init__(self, config: StorageConfig) -> None: + self.logger = get_logger(f"sql_{config.name}", in_ray_actor=True) + if not config.path: + config.path = default_storage_path( + storage_name=config.name, storage_type=config.storage_type + ) self.engine, self.table_model_cls = init_engine( - db_url=storage_config.path, - table_name=storage_config.name, - schema_type=storage_config.schema_type, + db_url=config.path, + table_name=config.name, + schema_type=config.schema_type, ) - self.logger.info(f"Init SQL storage at {storage_config.path}") + self.logger.info(f"Init SQL storage at {config.path}") self.session = sessionmaker(bind=self.engine) - self.max_retry_times = storage_config.max_retry_times - self.max_retry_interval = storage_config.max_retry_interval + self.max_retry_times = config.max_retry_times + self.max_retry_interval = config.max_retry_interval self.ref_count = 0 self.stopped = False # Assume that the auto-increment ID starts counting from 1, so the default offset should be 0. - self.offset = storage_config.index + self.offset = config.index @classmethod - def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): - if storage_config.schema_type is None: + def get_wrapper(cls, config: StorageConfig): + if config.schema_type is None: storage_cls = SQLTaskStorage else: storage_cls = SQLExperienceStorage - if storage_config.wrap_in_ray: + if config.wrap_in_ray: return ( ray.remote(storage_cls) .options( - name=f"sql-{storage_config.name}", - namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace, + name=f"sql-{config.name}", + namespace=config.ray_namespace or ray.get_runtime_context().namespace, get_if_exists=True, max_concurrency=5, ) - .remote(storage_config, config) + .remote(config) ) else: - return storage_cls(storage_config, config) + return storage_cls(config) @abstractmethod def write(self, data: List) -> None: @@ -90,12 +92,12 @@ def release(self) -> int: class SQLExperienceStorage(SQLStorage): """Used as trainer input.""" - def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: - super().__init__(storage_config, config) - self.batch_size = config.train_batch_size - self.max_timeout = storage_config.max_read_timeout + def __init__(self, config: StorageConfig) -> None: + super().__init__(config) + self.max_timeout = config.max_read_timeout + self.batch_size = config.batch_size # TODO: optimize the following logic - if storage_config.schema_type == "experience": + if config.schema_type == "experience": # NOTE: consistent with the old version of experience buffer self._read_method = self._read_priority else: @@ -191,15 +193,10 @@ def read(self, batch_size: Optional[int] = None) -> List[Experience]: return self._read_method(batch_size) @classmethod - def load_from_dataset( - cls, dataset: Dataset, storage_config: StorageConfig, config: BufferConfig - ) -> "SQLExperienceStorage": - storage = cls( - storage_config=storage_config, - config=config, - ) - formatter = FORMATTER.get(storage_config.schema_type)( - tokenizer_path=config.tokenizer_path, format_config=storage_config.format + def load_from_dataset(cls, dataset: Dataset, config: StorageConfig) -> "SQLExperienceStorage": + storage = cls(config) + formatter = FORMATTER.get(config.schema_type)( + 
tokenizer_path=config.tokenizer_path, format_config=config.format ) batch_size = storage.batch_size batch = [] @@ -216,20 +213,20 @@ def load_from_dataset( class SQLTaskStorage(SQLStorage): """Used as explorer input.""" - def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: - super().__init__(storage_config, config) + def __init__(self, config: StorageConfig) -> None: + super().__init__(config) self.batch_size = config.batch_size - self.is_eval = storage_config.is_eval - self.default_workflow_cls = WORKFLOWS.get(storage_config.default_workflow_type) # type: ignore - self.default_reward_fn_cls = REWARD_FUNCTIONS.get(storage_config.default_reward_fn_type) # type: ignore - self.formatter = TaskFormatter(storage_config) - self.offset = storage_config.index - if storage_config.total_steps: - self.total_samples = self.batch_size * storage_config.total_steps + self.is_eval = config.is_eval + self.default_workflow_cls = WORKFLOWS.get(config.default_workflow_type) # type: ignore + self.default_reward_fn_cls = REWARD_FUNCTIONS.get(config.default_reward_fn_type) # type: ignore + self.formatter = TaskFormatter(config) + self.offset = config.index + if config.total_steps: + self.total_samples = self.batch_size * config.total_steps else: - if storage_config.total_epochs > 1: + if config.total_epochs > 1: self.logger.warning( - f"SQL Storage do not support total_epochs, the value {storage_config.total_epochs} will be ignored" + f"SQL Storage do not support total_epochs, the value {config.total_epochs} will be ignored" ) self.total_samples = float("inf") @@ -260,13 +257,8 @@ def read(self, batch_size: Optional[int] = None) -> List[Task]: return [self.formatter.format(item.raw_task) for item in results] @classmethod - def load_from_dataset( - cls, dataset: Dataset, storage_config: StorageConfig, config: BufferConfig - ) -> "SQLTaskStorage": - storage = cls( - storage_config=storage_config, - config=config, - ) + def load_from_dataset(cls, dataset: Dataset, config: StorageConfig) -> "SQLTaskStorage": + storage = cls(config) batch_size = config.batch_size batch = [] for item in dataset: diff --git a/trinity/buffer/utils.py b/trinity/buffer/utils.py index 7db8dde867..9c9a32dc8c 100644 --- a/trinity/buffer/utils.py +++ b/trinity/buffer/utils.py @@ -2,8 +2,7 @@ import time from contextlib import contextmanager -from trinity.common.config import BufferConfig, StorageConfig -from trinity.common.constants import StorageType +from trinity.common.constants import CHECKPOINT_JOB_DIR_ENV_VAR, StorageType from trinity.utils.log import get_logger @@ -37,16 +36,16 @@ def retry_session(session_maker, max_retry_times: int, max_retry_interval: float session.close() -def default_storage_path(storage_config: StorageConfig, buffer_config: BufferConfig) -> str: - if buffer_config.cache_dir is None: - raise ValueError("Please call config.check_and_update() before using.") - if storage_config.storage_type == StorageType.SQL: - return "sqlite:///" + os.path.join( - buffer_config.cache_dir, - f"{storage_config.name}.db", +def default_storage_path(storage_name: str, storage_type: StorageType) -> str: + checkpoint_dir = os.environ.get(CHECKPOINT_JOB_DIR_ENV_VAR, None) + if checkpoint_dir is None: + raise ValueError( + f"Environment variable {CHECKPOINT_JOB_DIR_ENV_VAR} is not set. " + "This should not happen when using `trinity run` command." 
) + storage_dir = os.path.join(checkpoint_dir, "buffer") + os.makedirs(storage_dir, exist_ok=True) + if storage_type == StorageType.SQL: + return "sqlite:///" + os.path.join(storage_dir, f"{storage_name}.db") else: - return os.path.join( - buffer_config.cache_dir, - f"{storage_config.name}.jsonl", - ) + return os.path.join(storage_dir, f"{storage_name}.jsonl") diff --git a/trinity/buffer/writer/file_writer.py b/trinity/buffer/writer/file_writer.py index 5a579bf59c..9ccbd718f0 100644 --- a/trinity/buffer/writer/file_writer.py +++ b/trinity/buffer/writer/file_writer.py @@ -4,15 +4,15 @@ from trinity.buffer.buffer_writer import BufferWriter from trinity.buffer.storage.file import FileStorage -from trinity.common.config import BufferConfig, StorageConfig +from trinity.common.config import StorageConfig from trinity.common.constants import StorageType class JSONWriter(BufferWriter): - def __init__(self, meta: StorageConfig, config: BufferConfig): - assert meta.storage_type == StorageType.FILE - self.writer = FileStorage.get_wrapper(meta, config) - self.wrap_in_ray = meta.wrap_in_ray + def __init__(self, config: StorageConfig): + assert config.storage_type == StorageType.FILE + self.writer = FileStorage.get_wrapper(config) + self.wrap_in_ray = config.wrap_in_ray def write(self, data: List) -> None: if self.wrap_in_ray: diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py index 7e4f4a9ca1..2d62511c90 100644 --- a/trinity/buffer/writer/queue_writer.py +++ b/trinity/buffer/writer/queue_writer.py @@ -5,17 +5,16 @@ from trinity.buffer.buffer_writer import BufferWriter from trinity.buffer.storage.queue import QueueStorage -from trinity.common.config import BufferConfig, StorageConfig +from trinity.common.config import StorageConfig from trinity.common.constants import StorageType class QueueWriter(BufferWriter): """Writer of the Queue buffer.""" - def __init__(self, meta: StorageConfig, config: BufferConfig): - assert meta.storage_type == StorageType.QUEUE - self.config = config - self.queue = QueueStorage.get_wrapper(meta, config) + def __init__(self, config: StorageConfig): + assert config.storage_type == StorageType.QUEUE + self.queue = QueueStorage.get_wrapper(config) def write(self, data: List) -> None: ray.get(self.queue.put_batch.remote(data)) diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py index ebe4ed0267..eeec7be55e 100644 --- a/trinity/buffer/writer/sql_writer.py +++ b/trinity/buffer/writer/sql_writer.py @@ -4,18 +4,18 @@ from trinity.buffer.buffer_writer import BufferWriter from trinity.buffer.storage.sql import SQLStorage -from trinity.common.config import BufferConfig, StorageConfig +from trinity.common.config import StorageConfig from trinity.common.constants import StorageType class SQLWriter(BufferWriter): """Writer of the SQL buffer.""" - def __init__(self, meta: StorageConfig, config: BufferConfig) -> None: - assert meta.storage_type == StorageType.SQL + def __init__(self, config: StorageConfig) -> None: + assert config.storage_type == StorageType.SQL # we only support write RFT algorithm buffer for now - self.wrap_in_ray = meta.wrap_in_ray - self.db_wrapper = SQLStorage.get_wrapper(meta, config) + self.wrap_in_ray = config.wrap_in_ray + self.db_wrapper = SQLStorage.get_wrapper(config) def write(self, data: list) -> None: if self.wrap_in_ray: diff --git a/trinity/common/config.py b/trinity/common/config.py index f96077baf3..1337ed2e84 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ 
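The default path no longer depends on `BufferConfig.cache_dir`; it is derived from the job-level checkpoint directory that the launcher now exports. A quick illustration of the resulting layout (the directory names are examples only):

```python
import os

from trinity.buffer.utils import default_storage_path
from trinity.common.constants import CHECKPOINT_JOB_DIR_ENV_VAR, StorageType

# `trinity run` exports this for every Ray worker; set it manually elsewhere.
os.environ[CHECKPOINT_JOB_DIR_ENV_VAR] = "/ckpt/my_project/my_job"

default_storage_path("exp_buffer", StorageType.SQL)
# -> "sqlite:////ckpt/my_project/my_job/buffer/exp_buffer.db"
default_storage_path("exp_buffer", StorageType.QUEUE)
# -> "/ckpt/my_project/my_job/buffer/exp_buffer.jsonl"
```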
diff --git a/trinity/buffer/writer/file_writer.py b/trinity/buffer/writer/file_writer.py
index 5a579bf59c..9ccbd718f0 100644
--- a/trinity/buffer/writer/file_writer.py
+++ b/trinity/buffer/writer/file_writer.py
@@ -4,15 +4,15 @@

 from trinity.buffer.buffer_writer import BufferWriter
 from trinity.buffer.storage.file import FileStorage
-from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.config import StorageConfig
 from trinity.common.constants import StorageType


 class JSONWriter(BufferWriter):
-    def __init__(self, meta: StorageConfig, config: BufferConfig):
-        assert meta.storage_type == StorageType.FILE
-        self.writer = FileStorage.get_wrapper(meta, config)
-        self.wrap_in_ray = meta.wrap_in_ray
+    def __init__(self, config: StorageConfig):
+        assert config.storage_type == StorageType.FILE
+        self.writer = FileStorage.get_wrapper(config)
+        self.wrap_in_ray = config.wrap_in_ray

     def write(self, data: List) -> None:
         if self.wrap_in_ray:
diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py
index 7e4f4a9ca1..2d62511c90 100644
--- a/trinity/buffer/writer/queue_writer.py
+++ b/trinity/buffer/writer/queue_writer.py
@@ -5,17 +5,16 @@

 from trinity.buffer.buffer_writer import BufferWriter
 from trinity.buffer.storage.queue import QueueStorage
-from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.config import StorageConfig
 from trinity.common.constants import StorageType


 class QueueWriter(BufferWriter):
     """Writer of the Queue buffer."""

-    def __init__(self, meta: StorageConfig, config: BufferConfig):
-        assert meta.storage_type == StorageType.QUEUE
-        self.config = config
-        self.queue = QueueStorage.get_wrapper(meta, config)
+    def __init__(self, config: StorageConfig):
+        assert config.storage_type == StorageType.QUEUE
+        self.queue = QueueStorage.get_wrapper(config)

     def write(self, data: List) -> None:
         ray.get(self.queue.put_batch.remote(data))
diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py
index ebe4ed0267..eeec7be55e 100644
--- a/trinity/buffer/writer/sql_writer.py
+++ b/trinity/buffer/writer/sql_writer.py
@@ -4,18 +4,18 @@

 from trinity.buffer.buffer_writer import BufferWriter
 from trinity.buffer.storage.sql import SQLStorage
-from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.config import StorageConfig
 from trinity.common.constants import StorageType


 class SQLWriter(BufferWriter):
     """Writer of the SQL buffer."""

-    def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
-        assert meta.storage_type == StorageType.SQL
+    def __init__(self, config: StorageConfig) -> None:
+        assert config.storage_type == StorageType.SQL
         # we only support write RFT algorithm buffer for now
-        self.wrap_in_ray = meta.wrap_in_ray
-        self.db_wrapper = SQLStorage.get_wrapper(meta, config)
+        self.wrap_in_ray = config.wrap_in_ray
+        self.db_wrapper = SQLStorage.get_wrapper(config)

     def write(self, data: list) -> None:
         if self.wrap_in_ray:
diff --git a/trinity/common/config.py b/trinity/common/config.py
index f96077baf3..1337ed2e84 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -14,6 +14,7 @@
 from omegaconf import OmegaConf

 from trinity.common.constants import (
+    CHECKPOINT_JOB_DIR_ENV_VAR,
     EXPLORER_NAME,
     LOG_DIR_ENV_VAR,
     LOG_LEVEL_ENV_VAR,
@@ -112,9 +113,21 @@ class LoRAConfig:
     target_modules: str = "all-linear"


+@dataclass
+class ReplayBufferConfig:
+    """Config for replay buffer used in StorageType.QUEUE with use_priority_queue=True."""
+
+    priority_fn: str = "linear_decay"
+    reuse_cooldown_time: Optional[float] = None
+
+    priority_fn_args: Dict = field(default_factory=lambda: {"decay": 2.0})
+
+
 @dataclass
 class StorageConfig:
-    """Storage config."""
+    """Storage config.
+    Used for both taskset and experience buffer.
+    """

     name: str = ""
     storage_type: StorageType = StorageType.FILE
@@ -167,10 +180,97 @@ class StorageConfig:
     # ! DO NOT SET, automatically set from buffer.total_steps
     total_steps: Optional[int] = None  # automatically set

+    # ! DO NOT SET, automatically set from buffer.batch_size / train_batch_size
+    batch_size: int = 0
+
     # ! DO NOT SET, automatically set corresponding to train/eval
     is_eval: bool = False


+@dataclass
+class TasksetConfig:
+    name: str = ""
+    storage_type: StorageType = StorageType.FILE
+    path: Optional[str] = None
+
+    default_workflow_type: Optional[str] = None
+    default_reward_fn_type: Optional[str] = None
+    rollout_args: GenerationConfig = field(default_factory=GenerationConfig)
+    workflow_args: dict = field(default_factory=dict)
+    reward_fn_args: dict = field(default_factory=dict)
+
+    # used for StorageType.FILE
+    split: str = "train"
+    subset_name: Optional[str] = None
+    format: FormatConfig = field(default_factory=FormatConfig)
+
+    # used for StorageType.SQL
+    max_retry_times: int = 3
+    max_retry_interval: int = 1
+
+    enable_progress_bar: bool = False
+
+    # ! DO NOT SET, automatically load from checkpoint
+    index: int = 0
+    # ! DO NOT SET, automatically set from algorithm.repeat_times
+    repeat_times: int = 1
+    # ! DO NOT SET, automatically set based on train/eval
+    is_eval: bool = False
+    # ! DO NOT SET, automatically set from buffer.batch_size
+    batch_size: int = 0
+
+    def to_storage_config(self) -> StorageConfig:
+        storage_config = StorageConfig(
+            name=self.name,
+            storage_type=self.storage_type,
+            path=self.path,
+            repeat_times=self.repeat_times,
+            index=self.index,
+            split=self.split,
+            subset_name=self.subset_name,
+            format=self.format,
+            max_retry_times=self.max_retry_times,
+            max_retry_interval=self.max_retry_interval,
+            default_workflow_type=self.default_workflow_type,
+            default_reward_fn_type=self.default_reward_fn_type,
+            rollout_args=self.rollout_args,
+            workflow_args=self.workflow_args,
+            reward_fn_args=self.reward_fn_args,
+            enable_progress_bar=self.enable_progress_bar,
+            is_eval=self.is_eval,
+            batch_size=self.batch_size,
+        )
+        return storage_config
+
+
+@dataclass
+class ExperienceBufferConfig:
+    name: str = ""
+    storage_type: StorageType = StorageType.FILE
+
+    # used for StorageType.FILE
+    split: str = "train"
+    subset_name: Optional[str] = None
+    format: FormatConfig = field(default_factory=FormatConfig)
+
+    # used for StorageType.QUEUE
+    capacity: int = 10000
+    max_read_timeout: float = 1800
+    use_priority_queue: bool = False
+    reuse_cooldown_time: Optional[float] = None
+    replay_buffer_kwargs: dict = field(
+        default_factory=lambda: {"priority_fn": "linear_decay", "decay": 2.0}
+    )
+
+    # used for StorageType.SQL
+    max_retry_times: int = 3
+    max_retry_interval: int = 1
+
+    # ! DO NOT SET, automatically load from checkpoint
+    index: int = 0
+
+    # ! DO NOT SET, automatically set from buffer.batch_size
+    batch_size: int = 0
+
+
 @dataclass
 class OperatorConfig:
     name: str = ""
@@ -1123,6 +1223,7 @@ def get_envs(self) -> Dict[str, str]:
         """Get the environment variables from the config."""
         return {
             PLUGIN_DIRS_ENV_VAR: os.getenv(PLUGIN_DIRS_ENV_VAR, ""),
+            CHECKPOINT_JOB_DIR_ENV_VAR: self.checkpoint_job_dir,
             LOG_LEVEL_ENV_VAR: self.log.level,
             LOG_DIR_ENV_VAR: self.log.save_dir,
             LOG_NODE_IP_ENV_VAR: "1" if self.log.group_by_node else "0",
diff --git a/trinity/common/constants.py b/trinity/common/constants.py
index ad092603d2..4ba2b505b0 100644
--- a/trinity/common/constants.py
+++ b/trinity/common/constants.py
@@ -12,6 +12,7 @@

 # trinity env var names
 CHECKPOINT_ROOT_DIR_ENV_VAR = "TRINITY_CHECKPOINT_ROOT_DIR"
+CHECKPOINT_JOB_DIR_ENV_VAR = "TRINITY_CHECKPOINT_JOB_DIR"
 PREVIOUS_STAGE_CHECKPOINT_DIR_ENV_VAR = "TRINITY_PREV_STAGE_CKPT_DIR"
 MODEL_PATH_ENV_VAR = "TRINITY_MODEL_PATH"
 TASKSET_PATH_ENV_VAR = "TRINITY_TASKSET_PATH"
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index a10b523af2..e4bc5bf969 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -51,7 +51,7 @@ def __init__(self, config: Config):
         self.experience_pipeline = self._init_experience_pipeline()
         self.config.buffer.explorer_input.taskset.index = explorer_state.get("latest_task_index", 0)
         self.taskset = (
-            get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
+            get_buffer_reader(self.config.buffer.explorer_input.taskset)
             if self.config.mode != "serve"
             else None
         )
diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py
index 187d0d5adf..9c3e0f6c0e 100644
--- a/trinity/explorer/workflow_runner.py
+++ b/trinity/explorer/workflow_runner.py
@@ -171,7 +171,7 @@ def __init__(
     ) -> None:
         model, auxiliary_models = get_debug_inference_model(config)
         super().__init__(config, model, auxiliary_models, 0)
-        self.taskset = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer)
+        self.taskset = get_buffer_reader(config.buffer.explorer_input.taskset)
         self.output_file = output_file

     async def debug(self) -> None:
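The two new classes split the explorer-facing and trainer-facing halves of the old catch-all `StorageConfig`, and `TasksetConfig.to_storage_config()` produces the low-level config that the readers and writers above consume. The tests in the next commit call `to_storage_config()` on `ExperienceBufferConfig` as well, so a matching method is presumably added when config.py is reworked there. A hedged sketch of the intended flow (the path and workflow name are illustrative values, not from the patch):

```python
from trinity.common.config import TasksetConfig
from trinity.common.constants import StorageType

taskset = TasksetConfig(
    name="countdown",
    storage_type=StorageType.FILE,
    path="/data/countdown",                  # illustrative path
    split="train",
    default_workflow_type="math_workflow",   # illustrative registry key
)
taskset.batch_size = 4  # normally filled in automatically from buffer.batch_size

storage = taskset.to_storage_config()  # plain StorageConfig for get_buffer_reader
```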
From 69aa53c1d6305de6bda13dfaca0168429d52f9aa Mon Sep 17 00:00:00 2001
From: pxc
Date: Thu, 23 Oct 2025 17:54:05 +0800
Subject: [PATCH 02/21] fix buffer tests

---
 tests/buffer/experience_storage_test.py |  22 +--
 tests/buffer/file_test.py               |  64 +++-----
 tests/buffer/queue_test.py              |  47 +++---
 tests/buffer/sql_test.py                |   8 +-
 tests/buffer/task_storage_test.py       |   6 +-
 tests/explorer/explorer_test.py         |  10 +-
 tests/explorer/scheduler_test.py        |   4 +-
 tests/manager/synchronizer_test.py      |  12 +-
 tests/service/data_juicer_test.py       |  13 +-
 tests/tools.py                          |  21 ++-
 tests/trainer/trainer_test.py           |   8 +-
 trinity/buffer/buffer.py                |  18 ++-
 trinity/buffer/reader/queue_reader.py   |   2 +-
 trinity/buffer/storage/file.py          |   3 +-
 trinity/buffer/storage/sql.py           |   6 +-
 trinity/buffer/utils.py                 |  17 ---
 trinity/common/config.py                | 188 ++++++++++++++++--------
 17 files changed, 245 insertions(+), 204 deletions(-)

diff --git a/tests/buffer/experience_storage_test.py b/tests/buffer/experience_storage_test.py
index 8ba591e13f..6cf648bc70 100644
--- a/tests/buffer/experience_storage_test.py
+++ b/tests/buffer/experience_storage_test.py
@@ -10,7 +10,7 @@
 from tests.tools import RayUnittestBaseAysnc
 from trinity.buffer.reader.sql_reader import SQLReader
 from trinity.buffer.writer.sql_writer import SQLWriter
-from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.config import ExperienceBufferConfig
 from trinity.common.constants import StorageType
 from trinity.common.experience import EID, Experience
@@ -23,24 +23,22 @@ def setUp(self):
         self.put_batch_size = 2
         self.train_batch_size = 4
-        self.config = BufferConfig(
-            train_batch_size=self.train_batch_size,
-        )
         if os.path.exists(DB_PATH):
             os.remove(DB_PATH)

     @parameterized.expand([("sft",), ("dpo",)])
     async def test_sql_storage(self, schema_type):
-        meta = StorageConfig(
+        config = ExperienceBufferConfig(
             name="test_storage",
             schema_type=schema_type,
             storage_type=StorageType.SQL,
             max_read_timeout=3,
             path=f"sqlite:///{DB_PATH}",
+            batch_size=self.train_batch_size,
         )
-
-        writer = SQLWriter(meta, self.config)
-        reader = SQLReader(meta, self.config)
+        config = config.to_storage_config()
+        writer = SQLWriter(config)
+        reader = SQLReader(config)
         self.assertEqual(await writer.acquire(), 1)
         exps = [
             Experience(
@@ -90,15 +88,17 @@ def thread_read(reader, result_queue):
         self.assertRaises(StopIteration, reader.read, batch_size=1)

     async def test_sql_experience_buffer(self):
-        meta = StorageConfig(
+        config = ExperienceBufferConfig(
             name="test_storage",
             schema_type="experience",
             storage_type=StorageType.SQL,
             max_read_timeout=3,
             path=f"sqlite:///{DB_PATH}",
+            batch_size=self.train_batch_size,
         )
-        writer = SQLWriter(meta, self.config)
-        reader = SQLReader(meta, self.config)
+        config = config.to_storage_config()
+        writer = SQLWriter(config)
+        reader = SQLReader(config)
         self.assertEqual(await writer.acquire(), 1)
         for idx in range(self.total_num // self.put_batch_size):
             exps = [
diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py
index 6fc2868b5c..995f4df25a 100644
--- a/tests/buffer/file_test.py
+++ b/tests/buffer/file_test.py
@@ -1,4 +1,5 @@
 import os
+import shutil
 import unittest

 import ray
@@ -9,28 +10,14 @@
     get_unittest_dataset_config,
 )
 from trinity.buffer.buffer import get_buffer_reader, get_buffer_writer
-from trinity.buffer.utils import default_storage_path
-from trinity.common.config import StorageConfig
+from trinity.common.config import ExperienceBufferConfig
 from trinity.common.constants import StorageType


 class TestFileBuffer(unittest.IsolatedAsyncioTestCase):
-    temp_output_path = "tmp/test_file_buffer/"
-
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        os.makedirs(cls.temp_output_path, exist_ok=True)
-
-    @classmethod
-    def tearDownClass(cls):
-        super().tearDownClass()
-        if os.path.exists(cls.temp_output_path):
-            os.system(f"rm -rf {cls.temp_output_path}")
-
     def test_file_reader(self):  # noqa: C901
         """Test file reader."""
-        reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
+        reader = get_buffer_reader(self.config.buffer.explorer_input.taskset)

         tasks = []
         while True:
@@ -43,7 +30,9 @@ def test_file_reader(self):  # noqa: C901
         # test epoch and offset
         self.config.buffer.explorer_input.taskset.total_epochs = 2
         self.config.buffer.explorer_input.taskset.index = 4
-        reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
+        reader = get_buffer_reader(
+            self.config.buffer.explorer_input.taskset,
+        )
         tasks = []
         while True:
             try:
@@ -55,7 +44,7 @@ def test_file_reader(self):  # noqa: C901
         # test total steps and offset
         self.config.buffer.explorer_input.taskset.total_steps = 5
         self.config.buffer.explorer_input.taskset.index = 8
-        reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
+        reader = get_buffer_reader(self.config.buffer.explorer_input.taskset)
         tasks = []
         while True:
             try:
@@ -68,7 +57,7 @@ def test_file_reader(self):  # noqa: C901
         self.config.buffer.explorer_input.taskset.total_steps = None
         self.config.buffer.explorer_input.taskset.total_epochs = 3
         self.config.buffer.explorer_input.taskset.index = 20
-        reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
+        reader = get_buffer_reader(self.config.buffer.explorer_input.taskset)
         tasks = []
         while True:
             try:
@@ -80,7 +69,7 @@ def test_file_reader(self):  # noqa: C901
         # test offset > dataset_len with total_steps
         self.config.buffer.explorer_input.taskset.total_steps = 10
         self.config.buffer.explorer_input.taskset.index = 24
-        reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
+        reader = get_buffer_reader(self.config.buffer.explorer_input.taskset)
         tasks = []
         while True:
             try:
@@ -90,9 +79,7 @@ def test_file_reader(self):  # noqa: C901
         self.assertEqual(len(tasks), 40 - 24)

     async def test_file_writer(self):
-        writer = get_buffer_writer(
-            self.config.buffer.trainer_input.experience_buffer, self.config.buffer
-        )
+        writer = get_buffer_writer(self.config.buffer.trainer_input.experience_buffer)
         await writer.acquire()
         writer.write(
             [
@@ -109,10 +96,7 @@ async def test_file_writer(self):
         await writer.release()
         file_wrapper = ray.get_actor("json-test_buffer")
         self.assertIsNotNone(file_wrapper)
-        file_path = default_storage_path(
-            self.config.buffer.trainer_input.experience_buffer.name,
-            self.config.buffer.trainer_input.experience_buffer.storage_type,
-        )
+        file_path = self.config.buffer.trainer_input.experience_buffer.path
         with open(file_path, "r") as f:
             self.assertEqual(len(f.readlines()), 4)

@@ -121,23 +105,15 @@ def setUp(self):
         self.config.checkpoint_root_dir = get_checkpoint_path()
         dataset_config = get_unittest_dataset_config("countdown", "train")
         self.config.buffer.explorer_input.taskset = dataset_config
-        self.config.buffer.trainer_input.experience_buffer = StorageConfig(
+        self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="test_buffer", storage_type=StorageType.FILE
         )
-        self.config.buffer.trainer_input.experience_buffer.name = "test_buffer"
-        self.config.buffer.cache_dir = os.path.join(
-            self.config.checkpoint_root_dir, self.config.project, self.config.name, "buffer"
-        )
+        self.config.check_and_update()
+        ray.init(ignore_reinit_error=True, runtime_env={"env_vars": self.config.get_envs()})
         os.makedirs(self.config.buffer.cache_dir, exist_ok=True)
-        if os.path.exists(
-            default_storage_path(
-                self.config.buffer.trainer_input.experience_buffer.name,
-                self.config.buffer.trainer_input.experience_buffer.storage_type,
-            )
-        ):
-            os.remove(
-                default_storage_path(
-                    self.config.buffer.trainer_input.experience_buffer.name,
-                    self.config.buffer.trainer_input.experience_buffer.storage_type,
-                )
-            )
+        file_path = self.config.buffer.trainer_input.experience_buffer.path
+        if os.path.exists(file_path):
+            shutil.rmtree(file_path)
+
+    def tearDown(self):
+        ray.shutdown()
diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py
index 7f26bb166d..1fdfe95701 100644
--- a/tests/buffer/queue_test.py
+++ b/tests/buffer/queue_test.py
@@ -10,7 +10,7 @@
 from tests.tools import RayUnittestBaseAysnc
 from trinity.buffer.reader.queue_reader import QueueReader
 from trinity.buffer.writer.queue_writer import QueueWriter
-from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.config import ExperienceBufferConfig
 from trinity.common.constants import StorageType
 from trinity.common.experience import Experience
@@ -31,16 +31,18 @@ class TestQueueBuffer(RayUnittestBaseAysnc):
         ]
     )
     async def test_queue_buffer(self, name, use_priority_queue):
-        meta = StorageConfig(
+        config = ExperienceBufferConfig(
             name=name,
             schema_type="experience",
             storage_type=StorageType.QUEUE,
             max_read_timeout=3,
             path=BUFFER_FILE_PATH,
             use_priority_queue=use_priority_queue,
+            batch_size=self.train_batch_size,
         )
-        writer = QueueWriter(meta, self.config)
-        reader = QueueReader(meta, self.config)
+        config = config.to_storage_config()
+        writer = QueueWriter(config)
+        reader = QueueReader(config)
         self.assertEqual(await writer.acquire(), 1)
         exps = [
             Experience(
@@ -94,8 +96,8 @@ def thread_read(reader, result_queue):

     async def test_priority_queue_capacity(self):
         # test priority queue capacity
-        self.config.train_batch_size = 4
-        meta = StorageConfig(
+        self.train_batch_size = 4
+        config = ExperienceBufferConfig(
             name="test_buffer_small",
             schema_type="experience",
             storage_type=StorageType.QUEUE,
@@ -104,9 +106,11 @@ async def test_priority_queue_capacity(self):
             path=BUFFER_FILE_PATH,
             use_priority_queue=True,
             replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6},
+            batch_size=self.train_batch_size,
         )
-        writer = QueueWriter(meta, self.config)
-        reader = QueueReader(meta, self.config)
+        config = config.to_storage_config()
+        writer = QueueWriter(config)
+        reader = QueueReader(config)

         for i in range(12):
             writer.write(
@@ -149,16 +153,18 @@ async def test_priority_queue_capacity(self):

     async def test_queue_buffer_capacity(self):
         # test queue capacity
-        meta = StorageConfig(
+        config = ExperienceBufferConfig(
             name="test_buffer_small",
             schema_type="experience",
             storage_type=StorageType.QUEUE,
             max_read_timeout=3,
             capacity=4,
             path=BUFFER_FILE_PATH,
+            batch_size=self.train_batch_size,
         )
-        writer = QueueWriter(meta, self.config)
-        reader = QueueReader(meta, self.config)
+        config = config.to_storage_config()
+        writer = QueueWriter(config)
+        reader = QueueReader(config)
         writer.write([{"content": "hello"}])
         writer.write([{"content": "hi"}])
         writer.write([{"content": "hello"}])
@@ -178,7 +184,7 @@ def write_blocking_call():

     async def test_priority_queue_buffer_reuse(self):
         # test experience replay
-        meta = StorageConfig(
+        config = ExperienceBufferConfig(
             name="test_buffer_small",
             schema_type="experience",
             storage_type=StorageType.QUEUE,
@@ -188,9 +194,11 @@ async def test_priority_queue_buffer_reuse(self):
             use_priority_queue=True,
             reuse_cooldown_time=0.5,
             replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6},
+            batch_size=self.train_batch_size,
         )
-        writer = QueueWriter(meta, self.config)
-        reader = QueueReader(meta, self.config)
+        config = config.to_storage_config()
+        writer = QueueWriter(config)
+        reader = QueueReader(config)
         for i in range(4):
             writer.write(
                 [
@@ -302,7 +310,7 @@ def replace_call():

     async def test_priority_queue_reuse_count_control(self):
         # test experience replay with linear decay and use count control
-        meta = StorageConfig(
+        config = ExperienceBufferConfig(
             name="test_buffer_small",
             schema_type="experience",
             storage_type=StorageType.QUEUE,
@@ -317,9 +325,11 @@ async def test_priority_queue_reuse_count_control(self):
                 "use_count_limit": 2,
                 "sigma": 0.0,
             },
+            batch_size=self.train_batch_size,
         )
-        writer = QueueWriter(meta, self.config)
-        reader = QueueReader(meta, self.config)
+        config = config.to_storage_config()
+        writer = QueueWriter(config)
+        reader = QueueReader(config)
         for i in range(4):
             writer.write(
                 [
@@ -408,8 +418,5 @@ def setUp(self):
         self.put_batch_size = 2
         self.train_batch_size = 4

-        self.config = BufferConfig(
-            train_batch_size=self.train_batch_size,
-        )
         if os.path.exists(BUFFER_FILE_PATH):
             os.remove(BUFFER_FILE_PATH)
diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py
index afe0e02c1c..d3d7fd47ce 100644
--- a/tests/buffer/sql_test.py
+++ b/tests/buffer/sql_test.py
@@ -6,7 +6,7 @@
 from tests.tools import RayUnittestBaseAysnc
 from trinity.buffer.reader.sql_reader import SQLReader
 from trinity.buffer.writer.sql_writer import SQLWriter
-from trinity.common.config import StorageConfig
+from trinity.common.config import ExperienceBufferConfig
 from trinity.common.constants import StorageType
 from trinity.common.experience import Experience
@@ -18,15 +18,15 @@ async def test_sql_buffer_read_write(self) -> None:
         total_num = 8
         put_batch_size = 2
         read_batch_size = 4
-        meta = StorageConfig(
+        config = ExperienceBufferConfig(
             name="test_buffer",
             schema_type="experience",
             path=f"sqlite:///{db_path}",
             storage_type=StorageType.SQL,
             batch_size=read_batch_size,
         )
-        sql_writer = SQLWriter(meta)
-        sql_reader = SQLReader(meta)
+        sql_writer = SQLWriter(config.to_storage_config())
+        sql_reader = SQLReader(config.to_storage_config())
         exps = [
             Experience(
                 tokens=torch.tensor([float(j) for j in range(i + 1)]),
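The updated tests all follow the same pattern: build the role-specific config, convert it with `to_storage_config()`, and hand the result to a writer/reader pair. A condensed sketch of the round trip the SQL test above exercises (the database path and tensor contents are illustrative; the real test populates additional `Experience` fields, which are left at their defaults here):

```python
import torch

from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import ExperienceBufferConfig
from trinity.common.constants import StorageType
from trinity.common.experience import Experience

config = ExperienceBufferConfig(
    name="test_buffer",
    schema_type="experience",
    path="sqlite:///test.db",  # illustrative path
    storage_type=StorageType.SQL,
    batch_size=4,
)
writer = SQLWriter(config.to_storage_config())
reader = SQLReader(config.to_storage_config())

# Write a few experiences, then read them back in batches of `batch_size`.
writer.write([Experience(tokens=torch.zeros(8)) for _ in range(4)])
batch = reader.read()  # waits until batch_size experiences are available
```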
diff --git a/tests/buffer/task_storage_test.py b/tests/buffer/task_storage_test.py
index b0adda2aff..bcb1767ef2 100644
--- a/tests/buffer/task_storage_test.py
+++ b/tests/buffer/task_storage_test.py
@@ -33,19 +33,19 @@ def test_read_task(self, storage_type, is_eval, offset):
         config.buffer.explorer_input.taskset = get_unittest_dataset_config(
             "countdown"
         )  # 17 samples
-        config.buffer.batch_size = batch_size
         config.buffer.explorer_input.taskset.storage_type = storage_type
         config.buffer.explorer_input.taskset.is_eval = is_eval
         config.buffer.explorer_input.taskset.index = offset
+        config.buffer.explorer_input.taskset.batch_size = batch_size
         if storage_type == StorageType.SQL:
             dataset = datasets.load_dataset(
                 config.buffer.explorer_input.taskset.path, split="train"
             )
             config.buffer.explorer_input.taskset.path = f"sqlite:///{db_path}"
             SQLTaskStorage.load_from_dataset(
-                dataset, config.buffer.explorer_input.taskset, config.buffer
+                dataset, config.buffer.explorer_input.taskset.to_storage_config()
             )
-        reader = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer)
+        reader = get_buffer_reader(config.buffer.explorer_input.taskset)
         tasks = []
         try:
             while True:
diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index 27f5b0d455..565af64f22 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -21,9 +21,8 @@
     get_unittest_dataset_config,
 )
 from trinity.buffer import get_buffer_reader
-from trinity.buffer.utils import default_storage_path
 from trinity.cli.launcher import explore, run_stage
-from trinity.common.config import StorageConfig
+from trinity.common.config import ExperienceBufferConfig
 from trinity.common.constants import StorageType
 from trinity.explorer.explorer import Explorer
 from trinity.manager.state_manager import StateManager
@@ -116,10 +115,7 @@ def test_explorer(self):
         self.assertTrue(count >= 0)
         self.assertTrue(count <= 2 * 4)  # repeat_times * batch_size
         self.assertTrue(count % 2 == 0)  # should be multiple of repeat_times
-
-        exp_save_path = default_storage_path(
-            self.config.buffer.trainer_input.experience_buffer, self.config.buffer
-        )
+        exp_save_path = self.config.buffer.trainer_input.experience_buffer.path
         with open(exp_save_path, "r", encoding="utf-8") as f:
             lines = f.readlines()
             self.assertTrue(len(lines) <= 4 * 2 * 4)  # step * repeat_times * batch_size
@@ -169,7 +165,7 @@ def setUp(self):
         self.config.checkpoint_root_dir = get_checkpoint_path()
         self.config.explorer.api_port = 8010
         self.config.explorer.service_status_check_interval = 30
-        self.config.buffer.trainer_input.experience_buffer = StorageConfig(
+        self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="experience_buffer",
             storage_type=StorageType.SQL,
         )
diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py
index 446b191e6b..d53d492046 100644
--- a/tests/explorer/scheduler_test.py
+++ b/tests/explorer/scheduler_test.py
@@ -7,7 +7,7 @@
 import torch

 from tests.tools import get_template_config
-from trinity.common.config import StorageConfig
+from trinity.common.config import ExperienceBufferConfig
 from trinity.common.constants import StorageType
 from trinity.common.experience import EID, Experience
 from trinity.common.models.model import InferenceModel
@@ -237,7 +237,7 @@ def setUp(self):
         self.config.buffer.pad_token_id = 0
         self.config.buffer.explorer_output = (
             self.config.buffer.trainer_input.experience_buffer
-        ) = StorageConfig(
+        ) = ExperienceBufferConfig(
             name="test",
             storage_type=StorageType.QUEUE,
             schema_type="experience",
diff --git a/tests/manager/synchronizer_test.py b/tests/manager/synchronizer_test.py
index e4b1801ace..b8f9a9bdb8 100644
--- a/tests/manager/synchronizer_test.py
+++ b/tests/manager/synchronizer_test.py
@@ -23,7 +23,7 @@
 )
 from trinity.algorithm.algorithm import ALGORITHM_TYPE
 from trinity.cli.launcher import both, explore, train
-from trinity.common.config import Config, StorageConfig
+from trinity.common.config import Config, ExperienceBufferConfig
 from trinity.common.constants import StorageType, SyncMethod, SyncStyle
 from trinity.explorer.explorer import Explorer
 from trinity.trainer.trainer import Trainer
@@ -130,7 +130,7 @@ def test_synchronizer(self):
         config.cluster.node_num = 1
         config.model.model_path = get_model_path()
         config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
-        config.buffer.trainer_input.experience_buffer = StorageConfig(
+        config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="exp_buffer",
             storage_type=StorageType.QUEUE,
         )
@@ -149,7 +149,7 @@ def test_synchronizer(self):
         explorer1_config.explorer.name = "explorer1"
         explorer1_config.explorer.rollout_model.engine_num = 1
         explorer1_config.explorer.rollout_model.tensor_parallel_size = 1
-        explorer1_config.buffer.explorer_output = StorageConfig(
+        explorer1_config.buffer.explorer_output = ExperienceBufferConfig(
             name="exp_buffer",
             storage_type=StorageType.QUEUE,
         )
@@ -253,7 +253,7 @@ def test_synchronizer(self):
         config.cluster.node_num = 1
         config.model.model_path = get_model_path()
         config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
-        config.buffer.trainer_input.experience_buffer = StorageConfig(
+        config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="exp_buffer",
             storage_type=StorageType.QUEUE,
         )
@@ -273,7 +273,7 @@ def test_synchronizer(self):
         explorer1_config.explorer.name = "explorer1"
         explorer1_config.explorer.rollout_model.engine_num = 1
         explorer1_config.explorer.rollout_model.tensor_parallel_size = 1
-        explorer1_config.buffer.explorer_output = StorageConfig(
+        explorer1_config.buffer.explorer_output = ExperienceBufferConfig(
             name="exp_buffer",
             storage_type=StorageType.QUEUE,
         )
@@ -354,7 +354,7 @@ def test_synchronizer(self):
         config.trainer.total_steps = self.max_steps
         config.model.model_path = get_model_path()
         config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
-        config.buffer.trainer_input.experience_buffer = StorageConfig(
+        config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="exp_buffer",
             storage_type=StorageType.QUEUE,
         )
diff --git a/tests/service/data_juicer_test.py b/tests/service/data_juicer_test.py
index 096cc62a36..2fdd89ad85 100644
--- a/tests/service/data_juicer_test.py
+++ b/tests/service/data_juicer_test.py
@@ -14,14 +14,14 @@
 import torch
 from jsonargparse import Namespace

-from tests.tools import RayUnittestBase, RayUnittestBaseAysnc, get_template_config
+from tests.tools import RayUnittestBase, get_template_config
 from trinity.buffer.buffer import get_buffer_reader
 from trinity.buffer.pipelines import ExperiencePipeline, check_and_run_task_pipeline
 from trinity.common.config import (
     DataJuicerServiceConfig,
     OperatorConfig,
-    StorageConfig,
     TaskPipelineConfig,
+    TasksetConfig,
 )
 from trinity.common.experience import Experience
 from trinity.service.data_juicer.client import DataJuicerClient
@@ -138,7 +138,10 @@ def start_server(port):
         self.assertIsNone(client.server)


-class TestDataJuicerExperiencePipeline(RayUnittestBaseAysnc):
+class TestDataJuicerExperiencePipeline(unittest.IsolatedAsyncioTestCase):
+    def tearDown(self):
+        ray.shutdown()
+
     async def test_data_juicer_operators(self):
         config = get_template_config()
         config.service.data_juicer = DataJuicerServiceConfig(
@@ -201,7 +204,7 @@ async def test_data_juicer_operators(self):
         ]
         metrics = await pipeline.process.remote(exps)
         self.assertIsInstance(metrics, dict)
-        reader = get_buffer_reader(config.buffer.trainer_input.experience_buffer, config.buffer)
+        reader = get_buffer_reader(config.buffer.trainer_input.experience_buffer)
         filtered_exps = reader.read(batch_size=2)
         self.assertEqual(len(filtered_exps), 2)
         with self.assertRaises(TimeoutError):
@@ -257,7 +260,7 @@ def test_data_juicer_task_pipeline(self):
             ],
             target_fields=["question", "answer"],
         )
-        config.buffer.explorer_input.taskset = StorageConfig(
+        config.buffer.explorer_input.taskset = TasksetConfig(
             name="taskset",
             path=TASKSET_OUTPUT_DIR,
         )
diff --git a/tests/tools.py b/tests/tools.py
index 64a29819f8..8d8cab98a5 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -8,9 +8,10 @@

 from trinity.common.config import (
     Config,
+    ExperienceBufferConfig,
     FormatConfig,
     LoRAConfig,
-    StorageConfig,
+    TasksetConfig,
     load_config,
 )
 from trinity.common.constants import (
@@ -74,12 +75,10 @@ def get_lora_config() -> LoRAConfig:
     return LoRAConfig(name="lora", lora_rank=16, lora_alpha=16)


-def get_unittest_dataset_config(
-    dataset_name: str = "countdown", split: str = "train"
-) -> StorageConfig:
+def get_unittest_dataset_config(dataset_name: str = "countdown", split: str = "train"):
     if dataset_name == "countdown" or dataset_name == "copy_countdown":
         # Countdown dataset with 17 samples
-        return StorageConfig(
+        return TasksetConfig(
             name=dataset_name,
             path=os.path.join(os.path.dirname(__file__), "template", "data", "countdown"),
             split=split,
@@ -93,7 +92,7 @@ def get_unittest_dataset_config(
         )
     elif dataset_name in {"eval_short", "eval_long"}:
         # Eval_short dataset with 2 samples, eval_long dataset with 8 samples
-        return StorageConfig(
+        return TasksetConfig(
             name=dataset_name,
             path=os.path.join(os.path.dirname(__file__), "template", "data", dataset_name),
             split="test",
@@ -106,7 +105,7 @@ def get_unittest_dataset_config(
         )
     elif dataset_name == "gsm8k":
         # GSM8K dataset with 16
samples - return StorageConfig( + return TasksetConfig( name=dataset_name, path=os.path.join(os.path.dirname(__file__), "template", "data", "gsm8k"), split="train", @@ -119,7 +118,7 @@ def get_unittest_dataset_config( ) elif dataset_name == "sft_for_gsm8k": # SFT dataset with 8 samples - return StorageConfig( + return ExperienceBufferConfig( name=dataset_name, path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_for_gsm8k"), split="train", @@ -132,7 +131,7 @@ def get_unittest_dataset_config( ) elif dataset_name == "sft_with_tools": # SFT_with_tools dataset with 4 samples - return StorageConfig( + return ExperienceBufferConfig( name=dataset_name, path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_with_tools"), split="train", @@ -145,7 +144,7 @@ def get_unittest_dataset_config( ) elif dataset_name == "dpo": # HumanLike DPO dataset with 17 samples - return StorageConfig( + return ExperienceBufferConfig( name=dataset_name, path=os.path.join(os.path.dirname(__file__), "template", "data", "human_like"), split="train", @@ -158,7 +157,7 @@ def get_unittest_dataset_config( ) elif dataset_name == "geometry": # Multi-modal geometry dataset with 8 samples - return StorageConfig( + return TasksetConfig( name=dataset_name, path=os.path.join(os.path.dirname(__file__), "template", "data", "geometry"), split="train", diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index acb4325fce..3eaff088cc 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -27,9 +27,9 @@ AlgorithmConfig, BufferConfig, Config, + ExperienceBufferConfig, ExplorerInput, StageConfig, - StorageConfig, TrainerInput, ) from trinity.common.constants import ( @@ -287,7 +287,7 @@ def test_trainer(self, mock_load): batch_size=4, explorer_input=ExplorerInput(taskset=get_unittest_dataset_config("gsm8k")), trainer_input=TrainerInput( - experience_buffer=StorageConfig( + experience_buffer=ExperienceBufferConfig( name="test_queue_storage", max_read_timeout=20, storage_type=StorageType.QUEUE, @@ -495,7 +495,7 @@ def test_fully_async_mode(self): config.cluster.node_num = 1 config.model.model_path = get_model_path() config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") - config.buffer.trainer_input.experience_buffer = StorageConfig( + config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( name="exp_buffer", storage_type=StorageType.QUEUE, use_priority_queue=self.use_priority_queue, @@ -524,7 +524,7 @@ def test_fully_async_mode(self): config.cluster.node_num = 1 explorer1_config.explorer.rollout_model.engine_num = 1 explorer1_config.explorer.rollout_model.tensor_parallel_size = 1 - explorer1_config.buffer.trainer_input.experience_buffer = StorageConfig( + explorer1_config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( name="exp_buffer", storage_type=StorageType.QUEUE, ) diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py index 143130ab7c..eb47af2806 100644 --- a/trinity/buffer/buffer.py +++ b/trinity/buffer/buffer.py @@ -1,13 +1,21 @@ # -*- coding: utf-8 -*- """The buffer module""" +from typing import Union + from trinity.buffer.buffer_reader import BufferReader from trinity.buffer.buffer_writer import BufferWriter -from trinity.common.config import StorageConfig +from trinity.common.config import ExperienceBufferConfig, StorageConfig, TasksetConfig from trinity.common.constants import StorageType +BufferStorageConfig = Union[TasksetConfig, ExperienceBufferConfig, StorageConfig] + -def 
get_buffer_reader(storage_config: StorageConfig) -> BufferReader: +def get_buffer_reader(config: BufferStorageConfig) -> BufferReader: """Get a buffer reader for the given dataset name.""" + if not isinstance(config, StorageConfig): + storage_config: StorageConfig = config.to_storage_config() + else: + storage_config = config if storage_config.storage_type == StorageType.SQL: from trinity.buffer.reader.sql_reader import SQLReader @@ -32,8 +40,12 @@ def get_buffer_reader(storage_config: StorageConfig) -> BufferReader: raise ValueError(f"{storage_config.storage_type} not supported.") -def get_buffer_writer(storage_config: StorageConfig) -> BufferWriter: +def get_buffer_writer(config: BufferStorageConfig) -> BufferWriter: """Get a buffer writer for the given dataset name.""" + if not isinstance(config, StorageConfig): + storage_config: StorageConfig = config.to_storage_config() + else: + storage_config = config if storage_config.storage_type == StorageType.SQL: from trinity.buffer.writer.sql_writer import SQLWriter diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index 46036bfe86..e36dfc0ce9 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -16,7 +16,7 @@ class QueueReader(BufferReader): def __init__(self, config: StorageConfig): assert config.storage_type == StorageType.QUEUE self.timeout = config.max_read_timeout - self.read_batch_size = config.train_batch_size + self.read_batch_size = config.batch_size self.queue = QueueStorage.get_wrapper(config) def read(self, batch_size: Optional[int] = None) -> List: diff --git a/trinity/buffer/storage/file.py b/trinity/buffer/storage/file.py index 3de80bcde7..75b3499b3b 100644 --- a/trinity/buffer/storage/file.py +++ b/trinity/buffer/storage/file.py @@ -5,7 +5,6 @@ import ray -from trinity.buffer.utils import default_storage_path from trinity.common.config import StorageConfig from trinity.common.experience import EID, Experience from trinity.common.workflows import Task @@ -35,7 +34,7 @@ class FileStorage: def __init__(self, config: StorageConfig) -> None: if not config.path: - config.path = default_storage_path(config.name, config.storage_type) + raise ValueError("`path` is required for FILE storage type.") ext = os.path.splitext(config.path)[-1] if ext != ".jsonl" and ext != ".json": raise ValueError(f"File path must end with '.json' or '.jsonl', got {config.path}") diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py index 790813570d..ea3bf4342c 100644 --- a/trinity/buffer/storage/sql.py +++ b/trinity/buffer/storage/sql.py @@ -11,7 +11,7 @@ from trinity.buffer.schema import init_engine from trinity.buffer.schema.formatter import FORMATTER, TaskFormatter -from trinity.buffer.utils import default_storage_path, retry_session +from trinity.buffer.utils import retry_session from trinity.common.config import StorageConfig from trinity.common.experience import Experience from trinity.common.rewards import REWARD_FUNCTIONS @@ -33,9 +33,7 @@ class SQLStorage: def __init__(self, config: StorageConfig) -> None: self.logger = get_logger(f"sql_{config.name}", in_ray_actor=True) if not config.path: - config.path = default_storage_path( - storage_name=config.name, storage_type=config.storage_type - ) + raise ValueError("`path` is required for SQL storage type.") self.engine, self.table_model_cls = init_engine( db_url=config.path, table_name=config.name, diff --git a/trinity/buffer/utils.py b/trinity/buffer/utils.py index 9c9a32dc8c..b27ba662bc 100644 
--- a/trinity/buffer/utils.py +++ b/trinity/buffer/utils.py @@ -1,8 +1,6 @@ -import os import time from contextlib import contextmanager -from trinity.common.constants import CHECKPOINT_JOB_DIR_ENV_VAR, StorageType from trinity.utils.log import get_logger @@ -34,18 +32,3 @@ def retry_session(session_maker, max_retry_times: int, max_retry_interval: float raise e finally: session.close() - - -def default_storage_path(storage_name: str, storage_type: StorageType) -> str: - checkpoint_dir = os.environ.get(CHECKPOINT_JOB_DIR_ENV_VAR, None) - if checkpoint_dir is None: - raise ValueError( - f"Environment variable {CHECKPOINT_JOB_DIR_ENV_VAR} is not set. " - "This should not happen when using `trinity run` command." - ) - storage_dir = os.path.join(checkpoint_dir, "buffer") - os.makedirs(storage_dir, exist_ok=True) - if storage_type == StorageType.SQL: - return "sqlite:///" + os.path.join(storage_dir, f"{storage_name}.db") - else: - return os.path.join(storage_dir, f"{storage_name}.jsonl") diff --git a/trinity/common/config.py b/trinity/common/config.py index 1337ed2e84..e860c1900e 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -125,8 +125,8 @@ class ReplayBufferConfig: @dataclass class StorageConfig: - """Storage config. - Used for both taskset and experience buffer. + """Storage config for both taskset and experience buffer. + Not visible to users directly. Please use ExperienceBufferConfig or TasksetConfig instead. """ name: str = "" @@ -183,10 +183,14 @@ class StorageConfig: # ! DO NOT SET, automatically set from buffer.batch_size / train_batch_size batch_size: int = 0 + # ! DO NOT SET, automatically set from model.model_path + tokenizer_path: Optional[str] = None + # ! DO NOT SET, automatically set corresponding to train/eval is_eval: bool = False +@dataclass class TasksetConfig: name: str = "" storage_type: StorageType = StorageType.FILE @@ -217,6 +221,10 @@ class TasksetConfig: is_eval: bool = False # ! DO NOT SET, automatically set from buffer.batch_size batch_size: int = 0 + # ! DO NOT SET, automatically set from buffer.total_epochs + total_epochs: int = 1 # automatically set + # ! DO NOT SET, automatically set from buffer.total_steps + total_steps: Optional[int] = None # automatically set def to_storage_config(self) -> StorageConfig: storage_config = StorageConfig( @@ -224,7 +232,6 @@ def to_storage_config(self) -> StorageConfig: storage_type=self.storage_type, path=self.path, repeat_times=self.repeat_times, - index=self.index, split=self.split, subset_name=self.subset_name, format=self.format, @@ -236,20 +243,22 @@ def to_storage_config(self) -> StorageConfig: workflow_args=self.workflow_args, reward_fn_args=self.reward_fn_args, enable_progress_bar=self.enable_progress_bar, + index=self.index, is_eval=self.is_eval, batch_size=self.batch_size, + total_epochs=self.total_epochs, + total_steps=self.total_steps, ) return storage_config +@dataclass class ExperienceBufferConfig: + """Storage Config for trainer input experience buffer.""" + name: str = "" storage_type: StorageType = StorageType.FILE - - # used for StorageType.FILE - split: str = "train" - subset_name: Optional[str] = None - format: FormatConfig = field(default_factory=FormatConfig) + path: Optional[str] = None # used for StorageType.QUEUE capacity: int = 10000 @@ -264,11 +273,47 @@ class ExperienceBufferConfig: max_retry_times: int = 3 max_retry_interval: int = 1 - # ! 
DO NOT SET, automatically load from checkpoint - index: int = 0 + # used for StorageType.FILE + split: str = "train" + subset_name: Optional[str] = None + format: FormatConfig = field(default_factory=FormatConfig) + # ! DO NOT SET, automatically set + schema_type: Optional[str] = None + # ! DO NOT SET + index: int = 0 # ! DO NOT SET, automatically set from buffer.batch_size batch_size: int = 0 + # ! DO NOT SET, automatically set from model.model_path + tokenizer_path: Optional[str] = None + # ! DO NOT SET, automatically set from buffer.total_epochs + total_epochs: int = 1 # automatically set + # ! DO NOT SET, automatically set from buffer.total_steps + total_steps: Optional[int] = None # automatically set + + def to_storage_config(self) -> StorageConfig: + storage_config = StorageConfig( + name=self.name, + storage_type=self.storage_type, + path=self.path, + capacity=self.capacity, + max_read_timeout=self.max_read_timeout, + use_priority_queue=self.use_priority_queue, + reuse_cooldown_time=self.reuse_cooldown_time, + replay_buffer_kwargs=self.replay_buffer_kwargs, + max_retry_times=self.max_retry_times, + max_retry_interval=self.max_retry_interval, + split=self.split, + subset_name=self.subset_name, + format=self.format, + schema_type=self.schema_type, + index=self.index, + batch_size=self.batch_size, + tokenizer_path=self.tokenizer_path, + total_epochs=self.total_epochs, + total_steps=self.total_steps, + ) + return storage_config @dataclass @@ -295,9 +340,9 @@ class ExperiencePipelineConfig: # A dictionary of input buffers, buffers are indexed by their names. # users only need to set extra buffers here - inputs: Dict[str, StorageConfig] = field(default_factory=dict) + inputs: Dict[str, ExperienceBufferConfig] = field(default_factory=dict) # The output buffer will automatically set to the trainer input buffer, so we do not need to set it here. - output: Optional[StorageConfig] = None + output: Optional[ExperienceBufferConfig] = None @Experimental @@ -320,7 +365,7 @@ class TaskPipelineConfig: # e.g., /path/to/file.jsonl or /path/to/file.parquet, not a directory or huggingface path inputs: List[str] = field(default_factory=list) # Output task buffer, if not set, use `buffer.explorer_input.taskset`. In most cases, users do not need to set this field. 
- output: Optional[StorageConfig] = None + output: Optional[TasksetConfig] = None # The list of fields extracted from the input tasksets and processed into the output taskset target_fields: List[str] = field(default_factory=list) @@ -481,8 +526,8 @@ class ClusterConfig: class ExplorerInput: """Config for explorer input.""" - taskset: StorageConfig = field(default_factory=StorageConfig) - eval_tasksets: List[StorageConfig] = field(default_factory=list) + taskset: TasksetConfig = field(default_factory=TasksetConfig) + eval_tasksets: List[TasksetConfig] = field(default_factory=list) # The following args provide default values for the corresponding args in `taskset` and `eval_tasksets` default_workflow_type: Optional[str] = None default_eval_workflow_type: Optional[str] = None @@ -495,10 +540,10 @@ class TrainerInput: # The main experience buffer to be used in trainer # Commonly, it is also the output buffer of the Explorer - experience_buffer: Optional[StorageConfig] = None + experience_buffer: Optional[ExperienceBufferConfig] = None # Some auxiliary buffers to facilitate training (e.g., data mixing) - auxiliary_buffers: Dict[str, StorageConfig] = field(default_factory=dict) + auxiliary_buffers: Dict[str, ExperienceBufferConfig] = field(default_factory=dict) @dataclass @@ -763,11 +808,7 @@ def _check_interval(self) -> None: f"`eval_interval` is not a multiple of `sync_interval`; adjusted to the nearest integer={self.explorer.eval_interval}." ) - def _check_buffer(self) -> None: # noqa: C901 - # TODO: split this function into different buffer read/writer - # check explorer_input - trainer_input = self.buffer.trainer_input - experience_buffer = trainer_input.experience_buffer + def _check_explorer_input(self) -> None: explorer_input = self.buffer.explorer_input taskset = explorer_input.taskset @@ -783,27 +824,18 @@ def _check_buffer(self) -> None: # noqa: C901 "`buffer.explorer_input.taskset.repeat_times` is set to `algorithm.repeat_times`" f" (={self.algorithm.repeat_times})." ) - if self.mode == "train": - assert ( - experience_buffer is not None - ), "`buffer.trainer_input.experience_buffer` is required when `mode` is `train`." - experience_buffer.total_epochs = self.buffer.total_epochs - experience_buffer.total_steps = self.buffer.total_steps - else: - taskset.total_epochs = self.buffer.total_epochs - taskset.total_steps = self.buffer.total_steps + taskset.batch_size = self.buffer.batch_size set_if_none(taskset, "default_workflow_type", explorer_input.default_workflow_type) set_if_none(taskset, "default_reward_fn_type", explorer_input.default_reward_fn_type) set_if_none(taskset, "ray_namespace", self.ray_namespace) set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens) - remained_tasksets = [] for idx, dataset in enumerate(explorer_input.eval_tasksets): if not dataset.path: - logger.warning(f"Eval dataset [{dataset}]'s path is not configured. 
Skip.") - continue + raise ValueError(f"Eval dataset [{dataset}]'s path is not configured.") dataset.is_eval = True + dataset.batch_size = self.buffer.batch_size if not dataset.name: dataset.name = f"eval_taskset_{idx}" set_if_none(dataset, "repeat_times", 1) @@ -813,12 +845,17 @@ def _check_buffer(self) -> None: # noqa: C901 set_if_none(dataset, "default_reward_fn_type", explorer_input.default_reward_fn_type) set_if_none(dataset, "ray_namespace", self.ray_namespace) set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens) - remained_tasksets.append(dataset) - explorer_input.eval_tasksets = remained_tasksets - # check trainer_input.experience_buffer + if self.mode != "train": + taskset.total_epochs = self.buffer.total_epochs + taskset.total_steps = self.buffer.total_steps + + def _check_trainer_input(self) -> None: + trainer_input = self.buffer.trainer_input + experience_buffer = trainer_input.experience_buffer + if experience_buffer is None: - experience_buffer = trainer_input.experience_buffer = StorageConfig( + experience_buffer = trainer_input.experience_buffer = ExperienceBufferConfig( name="experience_buffer", storage_type=StorageType.QUEUE, ) @@ -831,27 +868,49 @@ def _check_buffer(self) -> None: # noqa: C901 from trinity.algorithm.algorithm import ALGORITHM_TYPE - experience_buffer.schema_type = ALGORITHM_TYPE.get(self.algorithm.algorithm_type).schema + if not experience_buffer.path: + experience_buffer.path = os.path.join( + self.buffer.cache_dir, "trainer_experience_buffer" # type: ignore[arg-type] + ) + logger.warning( + f"Auto set `buffer.trainer_input.experience_buffer.path` to {experience_buffer.path}" + ) + experience_buffer.schema_type = ALGORITHM_TYPE.get(self.algorithm.algorithm_type).schema + experience_buffer.batch_size = self.buffer.train_batch_size + experience_buffer.tokenizer_path = self.model.model_path set_if_none(experience_buffer, "ray_namespace", self.ray_namespace) set_if_none(experience_buffer.format, "chat_template", self.model.custom_chat_template) + for aux_name, aux_buffer in trainer_input.auxiliary_buffers.items(): + aux_buffer.batch_size = self.buffer.train_batch_size + aux_buffer.tokenizer_path = self.model.model_path + set_if_none(aux_buffer, "ray_namespace", self.ray_namespace) + if aux_buffer.path is None or aux_buffer.path == "": + raise ValueError( + f"`buffer.trainer_input.auxiliary_buffers[{aux_name}].path` is required, " + f"please set it to the path of the auxiliary buffer." + ) - # create buffer.cache_dir at ///buffer - self.buffer.cache_dir = os.path.abspath(os.path.join(self.checkpoint_job_dir, "buffer")) - try: - os.makedirs(self.buffer.cache_dir, exist_ok=True) - except Exception: - logger.warning( - f"Failed to create buffer dir {self.buffer.cache_dir}, please check " - f"your checkpoint directory: {self.checkpoint_job_dir}" - ) + if self.mode == "train": + assert ( + experience_buffer is not None + ), "`buffer.trainer_input.experience_buffer` is required when `mode` is `train`." 
+ experience_buffer.total_epochs = self.buffer.total_epochs + experience_buffer.total_steps = self.buffer.total_steps + + def _default_storage_path(self, storage_type: StorageType, name: str) -> str: + if storage_type == StorageType.SQL: + return "sqlite:///" + os.path.join(self.buffer.cache_dir, f"{name}.db") # type: ignore[arg-type] + else: + return os.path.join(self.buffer.cache_dir, f"{name}.jsonl") # type: ignore[arg-type] + def _check_data_processor(self) -> None: # check input/output buffers in pipelines experience_pipeline = self.data_processor.experience_pipeline if experience_pipeline is not None: if experience_pipeline.save_input and experience_pipeline.input_save_path is None: experience_pipeline.input_save_path = os.path.join( - self.buffer.cache_dir, "explorer_output.jsonl" + self.buffer.cache_dir, "explorer_output.jsonl" # type: ignore[arg-type] ) logger.info( f"Auto set `data_processor.experience_pipeline.input_save_path` to {experience_pipeline.input_save_path}" @@ -860,18 +919,13 @@ def _check_buffer(self) -> None: # noqa: C901 task_pipeline = self.data_processor.task_pipeline if task_pipeline is not None: if task_pipeline.output is None: - if taskset.path is not None: - task_pipeline.output = taskset - elif ( - experience_buffer.schema_type in {"dpo", "sft"} - and experience_buffer.path is not None - ): - task_pipeline.output = experience_buffer + if self.mode != "train": + task_pipeline.output = self.buffer.explorer_input.taskset + elif self.mode == "train" and self.algorithm.algorithm_type in {"dpo", "sft"}: + task_pipeline.output = self.buffer.trainer_input.experience_buffer else: raise ValueError( - "`data_processor.task_pipeline.output` is required when both " - "`buffer.explorer_input.taskset.path` and `buffer.trainer_input.experience_buffer.path` are " - "None" + "`data_processor.task_pipeline.output` is missing. Please set it to the desired output storage config." ) if task_pipeline.output.path and os.path.exists(task_pipeline.output.path): raise ValueError( @@ -879,6 +933,7 @@ def _check_buffer(self) -> None: # noqa: C901 "Please choose a different output path to avoid overwriting." 
) + def _check_buffer(self) -> None: # noqa: C901 # check train_batch_size if not self.buffer.train_batch_size: if self.mode == "train" or self.algorithm.algorithm_type in ["sft", "dpo"]: @@ -891,6 +946,16 @@ def _check_buffer(self) -> None: # noqa: C901 ) self.buffer.train_batch_size = self.buffer.batch_size * self.algorithm.repeat_times + # create buffer.cache_dir at ///buffer + self.buffer.cache_dir = os.path.abspath(os.path.join(self.checkpoint_job_dir, "buffer")) + try: + os.makedirs(self.buffer.cache_dir, exist_ok=True) + except Exception as e: + raise RuntimeError( + f"Failed to create buffer dir {self.buffer.cache_dir}, please check " + f"your checkpoint directory: {self.checkpoint_job_dir}" + ) from e + # set pad_token_id / tokenizer_path if self.buffer.pad_token_id is None: from transformers import AutoTokenizer @@ -908,7 +973,10 @@ def _check_buffer(self) -> None: # noqa: C901 except Exception: logger.warning(f"Failed to get pad token id from model {self.model.model_path}") self.buffer.pad_token_id = 0 - self.buffer.tokenizer_path = self.model.model_path + + self._check_explorer_input() + self._check_trainer_input() + self._check_data_processor() def _check_algorithm(self) -> None: from trinity.algorithm import ( From 94cc301968384f88331400c4cb9b221522ebfd90 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 23 Oct 2025 18:08:15 +0800 Subject: [PATCH 03/21] fix buffer tests --- tests/buffer/experience_pipeline_test.py | 14 ++++++++++---- tests/buffer/file_test.py | 1 + trinity/common/config.py | 9 ++++++--- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/tests/buffer/experience_pipeline_test.py b/tests/buffer/experience_pipeline_test.py index 3adcbccab1..53a3babbcb 100644 --- a/tests/buffer/experience_pipeline_test.py +++ b/tests/buffer/experience_pipeline_test.py @@ -7,7 +7,11 @@ from tests.tools import RayUnittestBaseAysnc, get_template_config from trinity.buffer import get_buffer_reader from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline -from trinity.common.config import ExperiencePipelineConfig, OperatorConfig +from trinity.common.config import ( + ExperienceBufferConfig, + ExperiencePipelineConfig, + OperatorConfig, +) from trinity.common.experience import EID, Experience BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_pipeline_buffer.jsonl") @@ -51,9 +55,11 @@ async def test_experience_pipeline(self): config.algorithm.advantage_fn = ( "grpo" # grpo will add an operator at the end of the pipeline ) + config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( + name="pipeline_test_experience_buffer", + max_read_timeout=3, + ) config.check_and_update() - config.buffer.trainer_input.experience_buffer.name = "pipeline_test_experience_buffer" - config.buffer.trainer_input.experience_buffer.max_read_timeout = 3 pipeline = ( ray.remote(ExperiencePipeline) @@ -70,7 +76,7 @@ async def test_experience_pipeline(self): ) # first experience of each task will be filtered out by the reward filter # tests - reader = get_buffer_reader(config.buffer.trainer_input.experience_buffer, config.buffer) + reader = get_buffer_reader(config.buffer.trainer_input.experience_buffer) exps = await reader.read_async(batch_size=task_num * (repeat_times - 1)) self.assertEqual(len(exps), task_num * (repeat_times - 1)) with self.assertRaises(TimeoutError): diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index 995f4df25a..06f538f3f1 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -102,6 +102,7 @@ async 
def test_file_writer(self): def setUp(self): self.config = get_template_config() + self.config.mode = "explore" self.config.checkpoint_root_dir = get_checkpoint_path() dataset_config = get_unittest_dataset_config("countdown", "train") self.config.buffer.explorer_input.taskset = dataset_config diff --git a/trinity/common/config.py b/trinity/common/config.py index e860c1900e..0ac94a7e27 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -866,16 +866,19 @@ def _check_trainer_input(self) -> None: ) experience_buffer.storage_type = StorageType.QUEUE - from trinity.algorithm.algorithm import ALGORITHM_TYPE + if not experience_buffer.name: + experience_buffer.name = "experience_buffer" if not experience_buffer.path: - experience_buffer.path = os.path.join( - self.buffer.cache_dir, "trainer_experience_buffer" # type: ignore[arg-type] + experience_buffer.path = self._default_storage_path( + experience_buffer.storage_type, experience_buffer.name ) logger.warning( f"Auto set `buffer.trainer_input.experience_buffer.path` to {experience_buffer.path}" ) + from trinity.algorithm.algorithm import ALGORITHM_TYPE + experience_buffer.schema_type = ALGORITHM_TYPE.get(self.algorithm.algorithm_type).schema experience_buffer.batch_size = self.buffer.train_batch_size experience_buffer.tokenizer_path = self.model.model_path From a07d5880e4cde9e61915dac7ba81d69b9e122d51 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 23 Oct 2025 18:12:19 +0800 Subject: [PATCH 04/21] fix explorer tests --- docs/sphinx_doc/source/tutorial/example_mix_algo.md | 5 ++--- docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md | 5 ++--- trinity/algorithm/sample_strategy/mix_sample_strategy.py | 4 +--- trinity/algorithm/sample_strategy/sample_strategy.py | 4 +--- trinity/explorer/explorer.py | 2 +- 5 files changed, 7 insertions(+), 13 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 56312a7b98..528bbd29a4 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -88,7 +88,7 @@ class MixSampleStrategy(SampleStrategy): usual_buffer_config = copy.deepcopy(buffer_config) usual_buffer_config.train_batch_size = tot_batch_size - expert_batch_size self.usual_exp_buffer = get_buffer_reader( - buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore + buffer_config.trainer_input.experience_buffer ) if buffer_config.trainer_input.auxiliary_buffers is None: @@ -100,8 +100,7 @@ class MixSampleStrategy(SampleStrategy): expert_buffer_config = copy.deepcopy(buffer_config) expert_buffer_config.train_batch_size = expert_batch_size self.expert_exp_buffer = get_buffer_reader( - buffer_config.trainer_input.auxiliary_buffers[self.sft_dataset_name], - expert_buffer_config, + buffer_config.trainer_input.auxiliary_buffers[self.sft_dataset_name] ) async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: diff --git a/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md b/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md index 4c0b575fd5..070152cf72 100644 --- a/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md @@ -80,7 +80,7 @@ class MixSampleStrategy(SampleStrategy): usual_buffer_config = copy.deepcopy(buffer_config) usual_buffer_config.train_batch_size = tot_batch_size - expert_batch_size self.usual_exp_buffer = get_buffer_reader( - buffer_config.trainer_input.experience_buffer, 
usual_buffer_config # type: ignore + buffer_config.trainer_input.experience_buffer ) if buffer_config.trainer_input.auxiliary_buffers is None: @@ -92,8 +92,7 @@ class MixSampleStrategy(SampleStrategy): expert_buffer_config = copy.deepcopy(buffer_config) expert_buffer_config.train_batch_size = expert_batch_size self.expert_exp_buffer = get_buffer_reader( - buffer_config.trainer_input.auxiliary_buffers[self.sft_dataset_name], - expert_buffer_config, + buffer_config.trainer_input.auxiliary_buffers[self.sft_dataset_name] ) async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 7bfa97d7a4..52f2ef0199 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -29,9 +29,7 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): # experience buffer usual_buffer_config = copy.deepcopy(buffer_config) usual_buffer_config.train_batch_size = tot_batch_size - expert_batch_size - self.usual_exp_buffer = get_buffer_reader( - buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore - ) + self.usual_exp_buffer = get_buffer_reader(buffer_config.trainer_input.experience_buffer) # type: ignore[arg-type] if buffer_config.trainer_input.auxiliary_buffers is None: raise ValueError( diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index cae6274411..9d2bddf798 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -39,9 +39,7 @@ def default_args(cls) -> dict: class DefaultSampleStrategy(SampleStrategy): def __init__(self, buffer_config: BufferConfig, **kwargs): super().__init__(buffer_config) - self.exp_buffer = get_buffer_reader( - buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore - ) + self.exp_buffer = get_buffer_reader(buffer_config.trainer_input.experience_buffer) # type: ignore[arg-type] async def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: metrics = {} diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index e4bc5bf969..fb9c0b15ff 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -266,7 +266,7 @@ async def eval(self): self.logger.info( f"Evaluation on {eval_taskset_config.name} at step {self.explore_step_num} started." 
) - eval_taskset = get_buffer_reader(eval_taskset_config, self.config.buffer) + eval_taskset = get_buffer_reader(eval_taskset_config) eval_batch_id = f"{self.explore_step_num}/{eval_taskset.name}" self.pending_eval_tasks.append((self.explore_step_num, eval_taskset.name)) while True: From 433008a2f3f5d686553ef609d3282e1df4bed205 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 23 Oct 2025 19:21:57 +0800 Subject: [PATCH 05/21] fix replay buffer config --- benchmark/config/countdown-template.yaml | 6 ++-- benchmark/config/gsm8k-template.yaml | 6 ++-- .../source/tutorial/example_step_wise.md | 8 +++-- .../source/tutorial/trinity_configs.md | 7 ++-- .../source_zh/tutorial/example_step_wise.md | 8 +++-- .../source_zh/tutorial/trinity_configs.md | 7 ++-- .../agentscopev1_websearch_agent.yaml | 3 +- .../alfworld.yaml | 3 +- examples/grpo_email_search/email_search.yaml | 3 +- examples/grpo_rubric_as_reward/rubric.yaml | 3 +- tests/buffer/queue_test.py | 35 +++++++++++-------- tests/trainer/trainer_test.py | 2 +- trinity/buffer/storage/queue.py | 22 +++++++----- trinity/common/config.py | 22 ++++-------- 14 files changed, 76 insertions(+), 59 deletions(-) diff --git a/benchmark/config/countdown-template.yaml b/benchmark/config/countdown-template.yaml index d10eb33573..7e6919dd8e 100644 --- a/benchmark/config/countdown-template.yaml +++ b/benchmark/config/countdown-template.yaml @@ -42,10 +42,10 @@ buffer: experience_buffer: name: experience_buffer storage_type: queue - use_priority_queue: true - replay_buffer_kwargs: + reply_buffer: priority_fn: linear_decay - decay: 0.1 + priority_fn_args: + decay: 0.1 explorer: runner_per_model: 8 max_timeout: 900 diff --git a/benchmark/config/gsm8k-template.yaml b/benchmark/config/gsm8k-template.yaml index 93f42166a6..feb0ce62e5 100644 --- a/benchmark/config/gsm8k-template.yaml +++ b/benchmark/config/gsm8k-template.yaml @@ -47,10 +47,10 @@ buffer: experience_buffer: name: experience_buffer storage_type: queue - use_priority_queue: true - replay_buffer_kwargs: + reply_buffer: priority_fn: linear_decay - decay: 0.1 + priority_fn_args: + decay: 0.1 explorer: runner_per_model: 8 max_timeout: 900 diff --git a/docs/sphinx_doc/source/tutorial/example_step_wise.md b/docs/sphinx_doc/source/tutorial/example_step_wise.md index 3239bbaf70..c9703c905d 100644 --- a/docs/sphinx_doc/source/tutorial/example_step_wise.md +++ b/docs/sphinx_doc/source/tutorial/example_step_wise.md @@ -81,7 +81,7 @@ In general multi-step scenarios, each run may generate various number of experie - `buffer.train_batch_size`: The number of experiences to be sampled from the buffer for training, which can be different from the number of generated experiences in each explore step. -- `buffer.trainer_input.use_priority_queue = true`: Using `PriorityQueue` allows the model to use the experiences with higher priority, which prefers newly-generated experiences by default. +- `buffer.trainer_input.experience_buffer.replay_buffer`: Using `PriorityQueue` allows the model to use the experiences with higher priority, which prefers newly-generated experiences by default. - `synchronizer.sync_style = dynamic_by_explorer`: The explorer determines when to synchronize the model weights with the trainer. 
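For intuition, a minimal sketch of what a `linear_decay`-style priority might compute is given below. The real implementations live in the `PRIORITY_FUNC` registry used by `AsyncPriorityQueue`, so the exact inputs and formula here (`model_version`, `use_count`) are assumptions for illustration only:

```python
# Hypothetical sketch of a linear-decay priority: experiences produced by a
# newer model version rank higher, and each reuse lowers the priority by
# `decay`, so fresh rollouts tend to be consumed before replayed ones.
def linear_decay(model_version: int, use_count: int, decay: float = 2.0) -> float:
    return model_version - decay * use_count
```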
@@ -126,7 +126,11 @@ buffer:
     experience_buffer:
       name: alfworld_buffer
       storage_type: queue
-      use_priority_queue: true
+      replay_buffer:
+        enable: true
+        priority_fn: linear_decay
+        priority_fn_args:
+          decay: 0.1
 explorer:
   max_repeat_times_per_runner: 1
   runner_per_model: 32
diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index c1af4871b7..6bcf4a20e6 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -339,8 +339,11 @@ buffer:
   - `enable_concatenated_multi_turn`: Enable concatenated multi-turn SFT data preprocess. Only for `messages` and only take effect with SFT algorithm.
   - `chat_template`: Specifies the chat template in string format. If not provided, use `model.custom_chat_template`.
   - `max_read_timeout`: The maximum waiting time (in seconds) to read new experience data. If exceeded, an incomplete batch will be returned directly. Only take effect when `storage_type` is `queue`. Default is 1800 seconds (30 minutes).
-  - `use_priority_queue`: Only take effect when `storage_type` is `queue`. If set to `True`, the queue will be a priority queue, which allows for prioritizing certain experiences over others. Default is `False`.
-  - `reuse_cooldown_time`: Only take effect when `storage_type` is `queue` and `use_priority_queue` is `True`. If set, it specifies the cooldown time (in seconds) for reusing experiences. If not specified, the default value is `None`, meaning experiences can not be reused.
+  - `replay_buffer`: Only takes effect when `storage_type` is `queue`. Used to configure the replay buffer for experience reuse.
+    - `enable`: Whether to enable the replay buffer. Default is `false`.
+    - `reuse_cooldown_time`: Cooldown time (in seconds) for reusing experiences. If not specified, the default value is `None`, meaning experiences cannot be reused.
+    - `priority_fn`: Experience priority function used to determine the order of experience reuse. Currently supports `linear_decay` and `linear_decay_use_count_control_randomization`.
+    - `priority_fn_args`: A dictionary of arguments passed to the priority function; the specific parameters depend on the selected priority function.
 - `auxiliary_buffers`: Optional buffers used for trainer. It is a dictionary where each key is the buffer name and the value is the buffer configuration. Each buffer configuration is similar to the `experience_buffer`.
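To make the mapping from these options to code concrete, here is a minimal sketch that builds the user-facing dataclasses described above and opens a reader on them. It mirrors the setup in `tests/buffer/queue_test.py` from this series; it assumes a running Ray session, and the buffer name and numeric values are illustrative only:

```python
from trinity.buffer import get_buffer_reader
from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig
from trinity.common.constants import StorageType

# Queue-backed experience buffer with the replay (priority) buffer enabled.
buffer_config = ExperienceBufferConfig(
    name="experience_buffer",
    storage_type=StorageType.QUEUE,
    max_read_timeout=3,
    batch_size=4,
    replay_buffer=ReplayBufferConfig(
        enable=True,
        priority_fn="linear_decay",
        reuse_cooldown_time=0.5,
        priority_fn_args={"decay": 0.6},
    ),
)

# get_buffer_reader accepts the user-facing config and converts it to the
# internal StorageConfig (via to_storage_config) before dispatching.
reader = get_buffer_reader(buffer_config)
batch = reader.read(batch_size=4)  # returns a list of Experience objects
```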
--- diff --git a/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md b/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md index f40250f5e3..2909add310 100644 --- a/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md +++ b/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md @@ -80,7 +80,7 @@ class StepWiseAlfworldWorkflow(RewardPropagationWorkflow): - `buffer.train_batch_size`:从 buffer 中采样用于训练的 experience 数量,可以与每次探索生成的 experience 数量不同。 -- `buffer.trainer_input.use_priority_queue = true`:使用 `PriorityQueue` 可使模型优先使用高优先级的 experience (默认为使用更新产生的 experience)。 +- `buffer.trainer_input.experience_buffer.replay_buffer`:使用 `PriorityQueue` 可使模型优先使用高优先级的 experience (默认为使用更新产生的 experience)。 - `synchronizer.sync_style = dynamic_by_explorer`:由 explorer 决定何时与 trainer 同步模型权重。 @@ -124,7 +124,8 @@ buffer: experience_buffer: name: alfworld_buffer storage_type: queue - use_priority_queue: true + replay_buffer: + enable: true explorer: max_repeat_times_per_runner: 1 runner_per_model: 16 @@ -154,11 +155,12 @@ trainer: ulysses_sequence_parallel_size: 1 ``` - 下面,我们提供运行 ALFWorld 任务的命令。 ## 示例:多步 ALFWorld + ### 环境准备 + 要安装 ALFWorld 环境,可按照以下说明操作。 1. 使用 pip 安装:`pip install alfworld[full]` diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index 998fe939e8..0b18c3eee0 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -336,8 +336,11 @@ buffer: - `enable_concatenated_multi_turn`: 启用拼接的多轮 SFT 数据预处理。仅适用于 `messages`,且仅在 SFT 算法中生效。 - `chat_template`: 以字符串形式指定 chat template。若未提供,则使用 `model.custom_chat_template`。 - `max_read_timeout`: 读取新 experience 数据的最大等待时间(秒)。若超时,则直接返回不完整批次。仅当 `storage_type` 为 `queue` 时生效。默认为 1800 秒(30 分钟)。 - - `use_priority_queue`: 仅当 `storage_type` 为 `queue` 时生效。若设为 `True`,队列为优先级队列,允许优先处理某些 experience。默认为 `False`。 - - `reuse_cooldown_time`: 仅当 `storage_type` 为 `queue` 且 `use_priority_queue` 为 `True` 时生效。若设置,指定 experience 重用的冷却时间(秒)。若未指定,默认为 `None`,表示 experience 不可被重复使用。 + - `replay_buffer`: 仅当 `storage_type` 为 `queue` 时生效。用于配置 experience 重用的回放缓冲区。 + - `enable`: 是否启用回放缓冲区。默认为 `false`。 + - `reuse_cooldown_time`: experience 重用的冷却时间(秒)。若未指定,默认为 `None`,表示 experience 不可被重复使用。 + - `priority_fn`: experience 优先级函数,用于确定 experience 的重用顺序。目前支持 `linear_decay` 和 `linear_decay_use_count_control_randomization`。 + - `priority_fn_args`: 传递给优先级函数的参数字典,具体参数取决于所选的优先级函数。 - `auxiliary_buffers`: trainer 使用的可选缓冲区。为字典结构,每个键为 buffer 名称,值为 buffer 配置。每个 buffer 配置与 `experience_buffer` 类似。 --- diff --git a/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml b/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml index cd9a87594a..5002a26b4a 100644 --- a/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml +++ b/examples/agentscope_websearch/agentscopev1_websearch_agent.yaml @@ -53,7 +53,8 @@ buffer: experience_buffer: name: experience_buffer storage_type: queue - use_priority_queue: true + replay_buffer: + enable: true explorer: eval_interval: 10 max_repeat_times_per_runner: 1 diff --git a/examples/grpo_alfworld_general_multi_step/alfworld.yaml b/examples/grpo_alfworld_general_multi_step/alfworld.yaml index e36016a3b2..5427b6829e 100644 --- a/examples/grpo_alfworld_general_multi_step/alfworld.yaml +++ b/examples/grpo_alfworld_general_multi_step/alfworld.yaml @@ -35,7 +35,8 @@ buffer: experience_buffer: name: alfworld_buffer storage_type: queue - use_priority_queue: true + replay_buffer: + enable: true explorer: max_repeat_times_per_runner: 
1 runner_per_model: 8 diff --git a/examples/grpo_email_search/email_search.yaml b/examples/grpo_email_search/email_search.yaml index fa3b96f2a5..c3227b1456 100644 --- a/examples/grpo_email_search/email_search.yaml +++ b/examples/grpo_email_search/email_search.yaml @@ -54,7 +54,8 @@ buffer: experience_buffer: name: experience_buffer storage_type: queue - use_priority_queue: true + replay_buffer: + enable: true explorer: eval_interval: 10 max_repeat_times_per_runner: 1 diff --git a/examples/grpo_rubric_as_reward/rubric.yaml b/examples/grpo_rubric_as_reward/rubric.yaml index 6e66dc348f..48e6909ba0 100644 --- a/examples/grpo_rubric_as_reward/rubric.yaml +++ b/examples/grpo_rubric_as_reward/rubric.yaml @@ -36,7 +36,8 @@ buffer: experience_buffer: name: experience_buffer storage_type: queue - use_priority_queue: true + replay_buffer: + enable: true explorer: eval_interval: 10 max_timeout: 3600 diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index 1fdfe95701..ef266bbff1 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -10,7 +10,7 @@ from tests.tools import RayUnittestBaseAysnc from trinity.buffer.reader.queue_reader import QueueReader from trinity.buffer.writer.queue_writer import QueueWriter -from trinity.common.config import ExperienceBufferConfig +from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig from trinity.common.constants import StorageType from trinity.common.experience import Experience @@ -37,9 +37,9 @@ async def test_queue_buffer(self, name, use_priority_queue): storage_type=StorageType.QUEUE, max_read_timeout=3, path=BUFFER_FILE_PATH, - use_priority_queue=use_priority_queue, batch_size=self.train_batch_size, ) + config.replay_buffer.enable = use_priority_queue config = config.to_storage_config() writer = QueueWriter(config) reader = QueueReader(config) @@ -104,8 +104,12 @@ async def test_priority_queue_capacity(self): max_read_timeout=1, capacity=8, path=BUFFER_FILE_PATH, - use_priority_queue=True, - replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6}, + replay_buffer=ReplayBufferConfig( + enable=True, + priority_fn="linear_decay", + reuse_cooldown_time=None, + priority_fn_args={"decay": 0.6}, + ), batch_size=self.train_batch_size, ) config = config.to_storage_config() @@ -191,9 +195,12 @@ async def test_priority_queue_buffer_reuse(self): max_read_timeout=3, capacity=4, # max total number of items; each item is List[Experience] path=BUFFER_FILE_PATH, - use_priority_queue=True, - reuse_cooldown_time=0.5, - replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6}, + replay_buffer=ReplayBufferConfig( + enable=True, + priority_fn="linear_decay", + reuse_cooldown_time=0.5, + priority_fn_args={"decay": 0.6}, + ), batch_size=self.train_batch_size, ) config = config.to_storage_config() @@ -317,14 +324,12 @@ async def test_priority_queue_reuse_count_control(self): max_read_timeout=3, capacity=4, # max total number of items; each item is List[Experience] path=BUFFER_FILE_PATH, - use_priority_queue=True, - reuse_cooldown_time=0.5, - replay_buffer_kwargs={ - "priority_fn": "linear_decay_use_count_control_randomization", - "decay": 1.2, - "use_count_limit": 2, - "sigma": 0.0, - }, + replay_buffer=ReplayBufferConfig( + enable=True, + priority_fn="linear_decay_use_count_control_randomization", + reuse_cooldown_time=0.5, + priority_fn_args={"decay": 1.2, "use_count_limit": 2, "sigma": 0.0}, + ), batch_size=self.train_batch_size, ) config = config.to_storage_config() diff --git 
a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 3eaff088cc..4e973865d9 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -498,8 +498,8 @@ def test_fully_async_mode(self): config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( name="exp_buffer", storage_type=StorageType.QUEUE, - use_priority_queue=self.use_priority_queue, ) + config.buffer.trainer_input.experience_buffer.replay_buffer.enable = self.use_priority_queue config.synchronizer.sync_method = SyncMethod.CHECKPOINT config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER config.synchronizer.sync_interval = 8 diff --git a/trinity/buffer/storage/queue.py b/trinity/buffer/storage/queue.py index 324122746c..9cd20f73d5 100644 --- a/trinity/buffer/storage/queue.py +++ b/trinity/buffer/storage/queue.py @@ -1,4 +1,5 @@ """Ray Queue storage""" + import asyncio import time from abc import ABC, abstractmethod @@ -96,14 +97,17 @@ def stopped(self) -> bool: def get_queue(cls, config: StorageConfig) -> "QueueBuffer": """Get a queue instance based on the storage configuration.""" logger = get_logger(__name__) - if config.use_priority_queue: - reuse_cooldown_time = config.reuse_cooldown_time - replay_buffer_kwargs = config.replay_buffer_kwargs + if config.replay_buffer.enable: capacity = config.capacity logger.info( - f"Using AsyncPriorityQueue with capacity {capacity}, reuse_cooldown_time {reuse_cooldown_time}." + f"Using AsyncPriorityQueue with capacity {capacity}, reuse_cooldown_time {config.replay_buffer.reuse_cooldown_time}." + ) + return AsyncPriorityQueue( + capacity=capacity, + reuse_cooldown_time=config.replay_buffer.reuse_cooldown_time, + priority_fn=config.replay_buffer.priority_fn, + priority_fn_args=config.replay_buffer.priority_fn_args, ) - return AsyncPriorityQueue(capacity, reuse_cooldown_time, **replay_buffer_kwargs) else: return AsyncQueue(capacity=config.capacity) @@ -140,9 +144,9 @@ class AsyncPriorityQueue(QueueBuffer): Attributes: capacity (int): Maximum number of items the queue can hold. This value is automatically adjusted to be at most twice the read batch size. - priority_groups (SortedDict): Maps priorities to deques of items with the same priority. - priority_fn (callable): Function used to determine the priority of an item. reuse_cooldown_time (float): Delay before reusing an item (set to infinity to disable). + priority_fn (callable): Function used to determine the priority of an item. + priority_groups (SortedDict): Maps priorities to deques of items with the same priority. """ def __init__( @@ -150,7 +154,7 @@ def __init__( capacity: int, reuse_cooldown_time: Optional[float] = None, priority_fn: str = "linear_decay", - **kwargs, + priority_fn_args: Optional[dict] = None, ): """ Initialize the async priority queue. 
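Replacing the open-ended `**kwargs` with an explicit `priority_fn_args` dict makes the constructor self-describing. A minimal direct-construction sketch follows; normally `QueueBuffer.get_queue` derives these arguments from `StorageConfig.replay_buffer`, and the numeric values here are illustrative, taken from the unit tests:

```python
from trinity.buffer.storage.queue import AsyncPriorityQueue

# Replay-enabled queue: capacity bounds the number of stored items, an item
# becomes reusable 0.5s after a read, and priority decays linearly per reuse.
queue = AsyncPriorityQueue(
    capacity=8,
    reuse_cooldown_time=0.5,
    priority_fn="linear_decay",
    priority_fn_args={"decay": 0.6},
)
```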
@@ -164,7 +168,7 @@ def __init__(
         self.capacity = capacity
         self.item_count = 0
         self.priority_groups = SortedDict()  # Maps priority -> deque of items
-        self.priority_fn = partial(PRIORITY_FUNC.get(priority_fn), **kwargs)
+        self.priority_fn = partial(PRIORITY_FUNC.get(priority_fn), **(priority_fn_args or {}))
         self.reuse_cooldown_time = reuse_cooldown_time
         self._condition = asyncio.Condition()  # For thread-safe operations
         self._closed = False
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 0ac94a7e27..7668a9a0e4 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -115,11 +115,12 @@ class LoRAConfig:

 @dataclass
 class ReplayBufferConfig:
-    """Config for replay buffer used in StorageType.QUEUE with use_priority_queue=True."""
+    """Config for replay buffer used in StorageType.QUEUE."""
+    enable: bool = False

     priority_fn: str = "linear_decay"
     reuse_cooldown_time: Optional[float] = None
     priority_fn_args: Dict = field(default_factory=lambda: {"decay": 2.0})

@@ -145,11 +145,7 @@ class StorageConfig:

     # used for StorageType.QUEUE
     capacity: int = 10000
     max_read_timeout: float = 1800
-    use_priority_queue: bool = False
-    reuse_cooldown_time: Optional[float] = None
-    replay_buffer_kwargs: dict = field(
-        default_factory=lambda: {"priority_fn": "linear_decay", "decay": 2.0}
-    )
+    replay_buffer: Optional[ReplayBufferConfig] = field(default_factory=ReplayBufferConfig)

     # used for StorageType.SQL
     max_retry_times: int = 3
@@ -263,11 +259,7 @@ class ExperienceBufferConfig:

     # used for StorageType.QUEUE
     capacity: int = 10000
     max_read_timeout: float = 1800
-    use_priority_queue: bool = False
-    reuse_cooldown_time: Optional[float] = None
-    replay_buffer_kwargs: dict = field(
-        default_factory=lambda: {"priority_fn": "linear_decay", "decay": 2.0}
-    )
+    replay_buffer: Optional[ReplayBufferConfig] = field(default_factory=ReplayBufferConfig)

     # used for StorageType.SQL
     max_retry_times: int = 3
     max_retry_interval: int = 1
@@ -277,6 +269,7 @@ class ExperienceBufferConfig:
     split: str = "train"
     subset_name: Optional[str] = None
     format: FormatConfig = field(default_factory=FormatConfig)
+    enable_progress_bar: Optional[bool] = False

     # !
DO NOT SET, automatically set schema_type: Optional[str] = None @@ -298,14 +291,13 @@ def to_storage_config(self) -> StorageConfig: path=self.path, capacity=self.capacity, max_read_timeout=self.max_read_timeout, - use_priority_queue=self.use_priority_queue, - reuse_cooldown_time=self.reuse_cooldown_time, - replay_buffer_kwargs=self.replay_buffer_kwargs, + replay_buffer=self.replay_buffer, max_retry_times=self.max_retry_times, max_retry_interval=self.max_retry_interval, split=self.split, subset_name=self.subset_name, format=self.format, + enable_progress_bar=self.enable_progress_bar, schema_type=self.schema_type, index=self.index, batch_size=self.batch_size, From 0018559ffbf7d2440396959ef919f48003bf85e8 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 23 Oct 2025 19:26:59 +0800 Subject: [PATCH 06/21] fix replay buffer config --- benchmark/config/countdown-template.yaml | 3 ++- benchmark/config/gsm8k-template.yaml | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/benchmark/config/countdown-template.yaml b/benchmark/config/countdown-template.yaml index 7e6919dd8e..2bd2e75259 100644 --- a/benchmark/config/countdown-template.yaml +++ b/benchmark/config/countdown-template.yaml @@ -42,7 +42,8 @@ buffer: experience_buffer: name: experience_buffer storage_type: queue - reply_buffer: + replay_buffer: + enable: true priority_fn: linear_decay priority_fn_args: decay: 0.1 diff --git a/benchmark/config/gsm8k-template.yaml b/benchmark/config/gsm8k-template.yaml index feb0ce62e5..59b96250a1 100644 --- a/benchmark/config/gsm8k-template.yaml +++ b/benchmark/config/gsm8k-template.yaml @@ -47,7 +47,8 @@ buffer: experience_buffer: name: experience_buffer storage_type: queue - reply_buffer: + replay_buffer: + enable: true priority_fn: linear_decay priority_fn_args: decay: 0.1 From 2013bf1563c23cfebc0ef6edcce64156bbea480d Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 23 Oct 2025 19:37:39 +0800 Subject: [PATCH 07/21] update doc --- docs/sphinx_doc/source/tutorial/trinity_configs.md | 8 +++----- docs/sphinx_doc/source_zh/tutorial/trinity_configs.md | 8 +++----- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 6bcf4a20e6..ce177cd10d 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -273,14 +273,12 @@ The configuration for each task dataset is defined as follows: - `name`: Name of the dataset. This name will be used as the Ray actor's name, so it must be unique. - `storage_type`: How the dataset is stored. Options: `file`, `queue`, `sql`. - `file`: The dataset is stored in `jsonl`/`parquet` files. The data file organization is required to meet the huggingface standard. *We recommand using this storage type for most cases.* - - `queue`: The dataset is stored in a queue. The queue is a simple FIFO queue that stores the task dataset. *Do not use this storage type for task dataset unless you know what you are doing.* - `sql`: The dataset is stored in a SQL database. *This type is unstable and will be optimized in the future versions.* - `path`: The path to the task dataset. - For `file` storage type, the path points to the directory that contains the task dataset files. - - For `queue` storage type, the path is optional. You can back up the data in the queue by specifying a sqlite database path here. - For `sql` storage type, the path points to the sqlite database file. 
-- `subset_name`: The subset name of the task dataset. Default is `None`. -- `split`: The split of the task dataset. Default is `train`. +- `subset_name`: The subset name of the task dataset, according to the `name` parameter in huggingface datasets `load_dataset` function. Default is `None`. +- `split`: The split of the task dataset, according to the `split` parameter in huggingface datasets `load_dataset` function. Default is `train`. - `repeat_times`: The number of rollouts generated for a task. If not set, it will be automatically set to `algorithm.repeat_times` for `taskset`, and `1` for `eval_tasksets`. - `rollout_args`: The parameters for rollout. - `temperature`: The temperature for sampling. @@ -324,7 +322,7 @@ buffer: - For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data. - For `file` storage type, the path points to the directory containing the dataset files. - For `sql` storage type, the path points to the SQLite database file. - - `format`: Defines keys for prompts and responses in the dataset. + - `format`: Mainly for SFT and DPO algorithm datasets, used to format the extracted data. - `prompt_type`: Specifies the type of prompts in the dataset. We support `plaintext`, `messages` for now. - `plaintext`: The prompt is in string format. - `messages`: The prompt is organized as a message list. diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index 0b18c3eee0..6398417368 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -270,14 +270,12 @@ buffer: - `name`: 数据集名称。该名称将用作 Ray actor 的名称,因此必须唯一。 - `storage_type`: 数据集的存储方式。选项:`file`、`queue`、`sql`。 - `file`: 数据集存储在 `jsonl`/`parquet` 文件中。数据文件组织需符合 HuggingFace 标准。*建议大多数情况下使用此存储类型。* - - `queue`: 数据集存储在队列中。队列是一个简单的 FIFO 队列,用于存储任务数据集。*除非你明确了解其用途,否则不要为此类数据集使用此类型。* - `sql`: 数据集存储在 SQL 数据库中。*此类型尚不稳定,将在未来版本中优化。* - `path`: 任务数据集的路径。 - 对于 `file` 类型,路径指向包含任务数据集文件的目录。 - - 对于 `queue` 类型,路径为可选。可通过在此指定 sqlite 数据库路径来备份队列数据。 - 对于 `sql` 类型,路径指向 sqlite 数据库文件。 -- `subset_name`: 任务数据集的子集名称。默认为 `None`。 -- `split`: 任务数据集的划分。默认为 `train`。 +- `subset_name`: 任务数据集的子集名称,对应 huggingface datasets `load_dataset` 函数中的 `name` 参数。默认为 `None`。 +- `split`: 任务数据集的划分。对应 huggingface datasets `load_dataset` 函数中的 `split` 参数。默认为 `train`。 - `repeat_times`: 为一个任务生成的 rollout 数量。若未设置,则自动设为 `algorithm.repeat_times`(`taskset`)或 `1`(`eval_tasksets`)。 - `rollout_args`: rollout 参数。 - `temperature`: 采样温度。 @@ -321,7 +319,7 @@ buffer: - 对于 `queue` 类型,此字段可选。可在此指定 SQLite 数据库或 JSON 文件路径以备份队列数据。 - 对于 `file` 类型,路径指向包含数据集文件的目录。 - 对于 `sql` 类型,路径指向 SQLite 数据库文件。 - - `format`: 定义数据集中 prompt 和 response 的键。 + - `format`: 主要针对 SFT 和 DPO 算法的数据集,用于规范化提取的数据。 - `prompt_type`: 指定数据集中 prompt 的类型。目前支持 `plaintext`、`messages`。 - `plaintext`: prompt 为 string 格式。 - `messages`: prompt 为消息列表。 From 4885f1524ab89632152f16586031678f6b759fd3 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 23 Oct 2025 19:43:23 +0800 Subject: [PATCH 08/21] fix pre-commit --- tests/buffer/file_test.py | 3 +-- tests/cli/launcher_test.py | 10 ---------- trinity/buffer/reader/file_reader.py | 3 --- trinity/common/config.py | 2 -- trinity/common/constants.py | 1 - 5 files changed, 1 insertion(+), 18 deletions(-) diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index 06f538f3f1..a56a5371e2 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -1,5 +1,4 @@ import os -import shutil 
import unittest import ray @@ -114,7 +113,7 @@ def setUp(self): os.makedirs(self.config.buffer.cache_dir, exist_ok=True) file_path = self.config.buffer.trainer_input.experience_buffer.path if os.path.exists(file_path): - shutil.rmtree(file_path) + os.remove(file_path) def tearDown(self): ray.shutdown() diff --git a/tests/cli/launcher_test.py b/tests/cli/launcher_test.py index 294bcf50fc..1b8ab142e8 100644 --- a/tests/cli/launcher_test.py +++ b/tests/cli/launcher_test.py @@ -21,7 +21,6 @@ TrainerInput, ) from trinity.common.constants import ( - CHECKPOINT_JOB_DIR_ENV_VAR, LOG_DIR_ENV_VAR, LOG_LEVEL_ENV_VAR, LOG_NODE_IP_ENV_VAR, @@ -119,9 +118,6 @@ def test_main_run_in_dlc(self, mock_init, mock_load, mock_both, mock_setup, mock runtime_env={ "env_vars": { launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins", - CHECKPOINT_JOB_DIR_ENV_VAR: os.path.join( - config.checkpoint_root_dir, config.project, config.name - ), LOG_DIR_ENV_VAR: config.log.save_dir, LOG_LEVEL_ENV_VAR: config.log.level, LOG_NODE_IP_ENV_VAR: "1", @@ -216,9 +212,6 @@ def test_multi_stage_run( runtime_env={ "env_vars": { launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins", - CHECKPOINT_JOB_DIR_ENV_VAR: os.path.join( - config.checkpoint_root_dir, config.project, config.name - ), LOG_DIR_ENV_VAR: os.path.join( config.checkpoint_root_dir, config.project, @@ -237,9 +230,6 @@ def test_multi_stage_run( runtime_env={ "env_vars": { launcher.PLUGIN_DIRS_ENV_VAR: "/path/to/plugins", - CHECKPOINT_JOB_DIR_ENV_VAR: os.path.join( - config.checkpoint_root_dir, config.project, config.name - ), LOG_DIR_ENV_VAR: os.path.join( config.checkpoint_root_dir, config.project, diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 93d2b5d54f..1d6ec89a43 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -154,6 +154,3 @@ def read(self, batch_size: Optional[int] = None) -> List: task = self.formatter.format(sample) tasks.append(task) return tasks - - def __len__(self) -> int: - return len(self.dataset.dataset) diff --git a/trinity/common/config.py b/trinity/common/config.py index 7668a9a0e4..f4f35c02a3 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -14,7 +14,6 @@ from omegaconf import OmegaConf from trinity.common.constants import ( - CHECKPOINT_JOB_DIR_ENV_VAR, EXPLORER_NAME, LOG_DIR_ENV_VAR, LOG_LEVEL_ENV_VAR, @@ -1286,7 +1285,6 @@ def get_envs(self) -> Dict[str, str]: """Get the environment variables from the config.""" return { PLUGIN_DIRS_ENV_VAR: os.getenv(PLUGIN_DIRS_ENV_VAR, ""), - CHECKPOINT_JOB_DIR_ENV_VAR: self.checkpoint_job_dir, LOG_LEVEL_ENV_VAR: self.log.level, LOG_DIR_ENV_VAR: self.log.save_dir, LOG_NODE_IP_ENV_VAR: "1" if self.log.group_by_node else "0", diff --git a/trinity/common/constants.py b/trinity/common/constants.py index 4ba2b505b0..ad092603d2 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -12,7 +12,6 @@ # trinity env var names CHECKPOINT_ROOT_DIR_ENV_VAR = "TRINITY_CHECKPOINT_ROOT_DIR" -CHECKPOINT_JOB_DIR_ENV_VAR = "TRINITY_CHECKPOINT_JOB_DIR" PREVIOUS_STAGE_CHECKPOINT_DIR_ENV_VAR = "TRINITY_PREV_STAGE_CKPT_DIR" MODEL_PATH_ENV_VAR = "TRINITY_MODEL_PATH" TASKSET_PATH_ENV_VAR = "TRINITY_TASKSET_PATH" From 478650a7e4a181f34f045832dd10be00488bd503 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 23 Oct 2025 19:44:48 +0800 Subject: [PATCH 09/21] fix comments --- trinity/buffer/buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py 
index eb47af2806..46929f06be 100644 --- a/trinity/buffer/buffer.py +++ b/trinity/buffer/buffer.py @@ -59,4 +59,4 @@ def get_buffer_writer(config: BufferStorageConfig) -> BufferWriter: return JSONWriter(storage_config) else: - raise ValueError(f"{storage_config} not supported.") + raise ValueError(f"{storage_config.storage_type} not supported.") From 60efda59658b6735dca23068d9df7dd41e8bf531 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 24 Oct 2025 10:14:25 +0800 Subject: [PATCH 10/21] clean unittest checkpoint dir --- .github/workflows/unittest.yaml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index 04a124d97e..435db3d362 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -97,6 +97,13 @@ jobs: fi fi + - name: Clean checkpoint dir + working-directory: trinity-${{ github.run_id }}/.github/workflows/docker + if: always() + run: | + docker compose exec trinity-node-1 rm -rf /mnt/checkpoints/* + continue-on-error: true + - name: Upload test results if: env.tests_run == 'true' || failure() uses: actions/upload-artifact@v4 From bec3656bc2cec89439cc7481f556aff274a7541a Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 24 Oct 2025 10:29:59 +0800 Subject: [PATCH 11/21] fix manager --- trinity/manager/config_manager.py | 16 +++++++++------- .../config_registry/buffer_config_manager.py | 10 +++++----- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index 31a95d0b76..f8c3996857 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -175,7 +175,7 @@ def _expert_buffer_part(self): self.get_configs("system_prompt") self.get_configs("reply_prefix") - if st.session_state["algorithm_type"] != "dpo": + if st.session_state["algorithm_type"] not in ["dpo", "sft"]: with st.expander("Taskset Configs", expanded=True): self.get_configs("taskset_path") self.get_configs("taskset_args") @@ -188,11 +188,11 @@ def _expert_buffer_part(self): with st.expander("Eval Tasksets Configs", expanded=True): self.get_configs("eval_tasksets") - if st.session_state["algorithm_type"] != "dpo": + if st.session_state["algorithm_type"] not in ["dpo", "sft"]: with st.expander("Experiences Buffer Configs", expanded=True): self.get_configs("storage_type") self.get_configs("experience_buffer_path") - self.get_configs("use_priority_queue") + self.get_configs("enable_replay_buffer") self.get_configs("reuse_cooldown_time", "priority_fn", "priority_decay") # TODO: used for SQL storage @@ -586,11 +586,13 @@ def _gen_buffer_config(self): del buffer_config["train_batch_size"] if st.session_state["algorithm_type"] != "dpo": experience_buffer = buffer_config["trainer_input"]["experience_buffer"] - experience_buffer["use_priority_queue"] = st.session_state["use_priority_queue"] - experience_buffer["reuse_cooldown_time"] = st.session_state["reuse_cooldown_time"] - experience_buffer["replay_buffer_kwargs"] = { + experience_buffer["replay_buffer"] = { + "enable": st.session_state["enable_replay_buffer"], "priority_fn": st.session_state["priority_fn"], - "decay": st.session_state["priority_decay"], + "reuse_cooldown_time": st.session_state["reuse_cooldown_time"], + "priority_fn_args": { + "decay": st.session_state["priority_decay"], + }, } if st.session_state["mode"] != "train": diff --git a/trinity/manager/config_registry/buffer_config_manager.py b/trinity/manager/config_registry/buffer_config_manager.py index 
de86e77ac4..c44c390cff 100644 --- a/trinity/manager/config_registry/buffer_config_manager.py +++ b/trinity/manager/config_registry/buffer_config_manager.py @@ -314,12 +314,12 @@ def on_change(): @CONFIG_GENERATORS.register_config(default_value=False) -def set_use_priority_queue(**kwargs): - st.checkbox("Use Priority Queue", **kwargs) +def set_enable_replay_buffer(**kwargs): + st.checkbox("Enable Replay Buffer", **kwargs) @CONFIG_GENERATORS.register_config( - default_value=None, visible=lambda: st.session_state["use_priority_queue"] + default_value=None, visible=lambda: st.session_state["enable_replay_buffer"] ) def set_reuse_cooldown_time(**kwargs): st.number_input( @@ -333,7 +333,7 @@ def set_reuse_cooldown_time(**kwargs): @CONFIG_GENERATORS.register_config( - default_value="linear_decay", visible=lambda: st.session_state["use_priority_queue"] + default_value="linear_decay", visible=lambda: st.session_state["enable_replay_buffer"] ) def set_priority_fn(**kwargs): candidates = list(PRIORITY_FUNC.modules.keys()) @@ -345,7 +345,7 @@ def set_priority_fn(**kwargs): @CONFIG_GENERATORS.register_config( - default_value=0.1, visible=lambda: st.session_state["use_priority_queue"] + default_value=0.1, visible=lambda: st.session_state["enable_replay_buffer"] ) def set_priority_decay(**kwargs): st.number_input( From b0966644e2430584ba2863b82b083499a051edb4 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 24 Oct 2025 12:08:20 +0800 Subject: [PATCH 12/21] fix train_batch_size --- .../source/tutorial/example_mix_algo.md | 16 ++++++++-------- .../source_zh/tutorial/example_mix_algo.md | 16 ++++++++-------- .../sample_strategy/mix_sample_strategy.py | 14 ++++++++------ trinity/buffer/reader/file_reader.py | 2 +- 4 files changed, 25 insertions(+), 23 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 528bbd29a4..f5b4a533ae 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -85,11 +85,9 @@ class MixSampleStrategy(SampleStrategy): expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) # experience buffer - usual_buffer_config = copy.deepcopy(buffer_config) - usual_buffer_config.train_batch_size = tot_batch_size - expert_batch_size - self.usual_exp_buffer = get_buffer_reader( - buffer_config.trainer_input.experience_buffer - ) + usual_buffer_config = copy.deepcopy(buffer_config.trainer_input.experience_buffer) + usual_buffer_config.batch_size = tot_batch_size - expert_batch_size + self.usual_exp_buffer = get_buffer_reader(usual_buffer_config) if buffer_config.trainer_input.auxiliary_buffers is None: raise ValueError( @@ -97,11 +95,13 @@ class MixSampleStrategy(SampleStrategy): ) # expert experience buffer - expert_buffer_config = copy.deepcopy(buffer_config) - expert_buffer_config.train_batch_size = expert_batch_size - self.expert_exp_buffer = get_buffer_reader( + expert_buffer_config = copy.deepcopy( buffer_config.trainer_input.auxiliary_buffers[self.sft_dataset_name] ) + expert_buffer_config.batch_size = expert_batch_size + self.expert_exp_buffer = get_buffer_reader( + expert_buffer_config, + ) async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: metrics = {} diff --git a/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md b/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md index 070152cf72..6c857e70aa 100644 --- a/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md +++ 
b/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md @@ -77,11 +77,9 @@ class MixSampleStrategy(SampleStrategy): expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) # experience buffer - usual_buffer_config = copy.deepcopy(buffer_config) - usual_buffer_config.train_batch_size = tot_batch_size - expert_batch_size - self.usual_exp_buffer = get_buffer_reader( - buffer_config.trainer_input.experience_buffer - ) + usual_buffer_config = copy.deepcopy(buffer_config.trainer_input.experience_buffer) + usual_buffer_config.batch_size = tot_batch_size - expert_batch_size + self.usual_exp_buffer = get_buffer_reader(usual_buffer_config) if buffer_config.trainer_input.auxiliary_buffers is None: raise ValueError( @@ -89,11 +87,13 @@ class MixSampleStrategy(SampleStrategy): ) # expert experience buffer - expert_buffer_config = copy.deepcopy(buffer_config) - expert_buffer_config.train_batch_size = expert_batch_size - self.expert_exp_buffer = get_buffer_reader( + expert_buffer_config = copy.deepcopy( buffer_config.trainer_input.auxiliary_buffers[self.sft_dataset_name] ) + expert_buffer_config.batch_size = expert_batch_size + self.expert_exp_buffer = get_buffer_reader( + expert_buffer_config, + ) async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: metrics = {} diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 52f2ef0199..054e77ae6b 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -27,9 +27,9 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) # experience buffer - usual_buffer_config = copy.deepcopy(buffer_config) - usual_buffer_config.train_batch_size = tot_batch_size - expert_batch_size - self.usual_exp_buffer = get_buffer_reader(buffer_config.trainer_input.experience_buffer) # type: ignore[arg-type] + usual_buffer_config = copy.deepcopy(buffer_config.trainer_input.experience_buffer) + usual_buffer_config.batch_size = tot_batch_size - expert_batch_size + self.usual_exp_buffer = get_buffer_reader(usual_buffer_config) # type: ignore[arg-type] if buffer_config.trainer_input.auxiliary_buffers is None: raise ValueError( @@ -49,10 +49,12 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): expert_storage_config.schema_type = "sft" # expert experience buffer - expert_buffer_config = copy.deepcopy(buffer_config) - expert_buffer_config.train_batch_size = expert_batch_size + expert_buffer_config = copy.deepcopy( + buffer_config.trainer_input.auxiliary_buffers[self.sft_dataset_name] + ) + expert_buffer_config.batch_size = expert_batch_size self.expert_exp_buffer = get_buffer_reader( - buffer_config.trainer_input.auxiliary_buffers[self.sft_dataset_name], + expert_buffer_config, ) async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 1d6ec89a43..edc5577dac 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -105,7 +105,7 @@ def __init__(self, config: StorageConfig): self.formatter = FORMATTER.get(config.schema_type)( tokenizer_path=config.tokenizer_path, format_config=config.format ) - self.read_batch_size = config.train_batch_size + self.read_batch_size = config.batch_size self.dataset = _HFBatchReader( load_dataset(config.path, name=config.subset_name, split=config.split), 
name=config.name, From aa3cc7d9717d7aa26e429c919ded7928c7b20907 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 24 Oct 2025 13:44:26 +0800 Subject: [PATCH 13/21] fix explorer test --- tests/explorer/explorer_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index 565af64f22..de751e42d9 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -244,7 +244,6 @@ async def test_serve(self): # noqa: C901 self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 5 buffer_reader = get_buffer_reader( self.config.buffer.trainer_input.experience_buffer, - self.config.buffer, ) exps = await buffer_reader.read_async(batch_size=10) for exp in exps: From 83073434d9eecb38efb1279b536c032290ec44d9 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 24 Oct 2025 16:14:29 +0800 Subject: [PATCH 14/21] fix tests --- tests/buffer/file_test.py | 2 +- tests/buffer/task_scheduler_test.py | 6 +++--- trinity/common/config.py | 13 +++++++++---- trinity/common/verl_config.py | 2 +- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index a56a5371e2..0676c4b100 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -16,7 +16,7 @@ class TestFileBuffer(unittest.IsolatedAsyncioTestCase): def test_file_reader(self): # noqa: C901 """Test file reader.""" - reader = get_buffer_reader(self.config.buffer.explorer_input.taskset) + reader = get_buffer_reader(self.config.buffer.explorer_input.tasksets[0]) tasks = [] while True: diff --git a/tests/buffer/task_scheduler_test.py b/tests/buffer/task_scheduler_test.py index ebaf6ed697..ee25f5856e 100644 --- a/tests/buffer/task_scheduler_test.py +++ b/tests/buffer/task_scheduler_test.py @@ -7,7 +7,7 @@ from tests.tools import get_template_config from trinity.buffer.task_scheduler import TasksetScheduler -from trinity.common.config import FormatConfig, StorageConfig, TaskSelectorConfig +from trinity.common.config import FormatConfig, TaskSelectorConfig, TasksetConfig from trinity.common.workflows.workflow import Task @@ -193,7 +193,7 @@ async def test_task_scheduler(self, task_selector_kwargs, batch_tasks_orders) -> config.buffer.total_epochs = 2 config.buffer.explorer_input.taskset = None config.buffer.explorer_input.tasksets = [ - StorageConfig( + TasksetConfig( name="subset_1", path=os.path.join( os.path.dirname(__file__), @@ -215,7 +215,7 @@ async def test_task_scheduler(self, task_selector_kwargs, batch_tasks_orders) -> **task_selector_kwargs, ), ), - StorageConfig( + TasksetConfig( name="subset_2", path=os.path.join( os.path.dirname(__file__), diff --git a/trinity/common/config.py b/trinity/common/config.py index 371c676ee7..a466af29ad 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -212,6 +212,7 @@ class TasksetConfig: rollout_args: GenerationConfig = field(default_factory=GenerationConfig) workflow_args: dict = field(default_factory=dict) reward_fn_args: dict = field(default_factory=dict) + task_selector: TaskSelectorConfig = field(default_factory=TaskSelectorConfig) # used for StorageType.FILE split: str = "train" @@ -242,6 +243,7 @@ def to_storage_config(self) -> StorageConfig: name=self.name, storage_type=self.storage_type, path=self.path, + task_selector=self.task_selector, repeat_times=self.repeat_times, split=self.split, subset_name=self.subset_name, @@ -533,7 +535,7 @@ class ClusterConfig: class ExplorerInput: """Config for explorer input.""" - taskset: 
TasksetConfig = None + taskset: Optional[TasksetConfig] = None tasksets: List[TasksetConfig] = field(default_factory=list) eval_tasksets: List[TasksetConfig] = field(default_factory=list) # The following args provide default values for the corresponding args in `taskset` and `eval_tasksets` @@ -817,6 +819,10 @@ def _check_interval(self) -> None: ) def _check_explorer_input(self) -> None: + if self.mode == "train": + # no need to check explorer_input in train mode + return + explorer_input = self.buffer.explorer_input if explorer_input.taskset: @@ -824,9 +830,8 @@ def _check_explorer_input(self) -> None: raise ValueError("Do not support setting `taskset` and `tasksets` simultaneously!") explorer_input.tasksets = [explorer_input.taskset] explorer_input.taskset = None - else: - if len(explorer_input.tasksets) == 0: - explorer_input.tasksets = [StorageConfig()] + elif len(explorer_input.tasksets) == 0: + raise ValueError("At least one taskset should be provided in explorer_input!") tasksets = explorer_input.tasksets for i, taskset in enumerate(tasksets): diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 1954833191..5d7a8247e6 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -415,7 +415,7 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 self.actor_rollout_ref.actor.ppo_mini_batch_size = config.buffer.train_batch_size self.actor_rollout_ref.rollout.temperature = config.buffer.explorer_input.tasksets[ 0 - ].rollout_args.temperature + ].rollout_args.temperature if config.buffer.explorer_input.tasksets else 1.0 self.actor_rollout_ref.rollout.n = config.algorithm.repeat_times if self.actor_rollout_ref.actor.grad_clip is None: self.actor_rollout_ref.actor.grad_clip = config.trainer.grad_clip From 4a313339eddd10b9b2408898303c83edd18e712e Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 24 Oct 2025 16:23:39 +0800 Subject: [PATCH 15/21] fix buffer tests --- tests/buffer/file_test.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index 0676c4b100..d4d3fe1c04 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -27,10 +27,10 @@ def test_file_reader(self): # noqa: C901 self.assertEqual(len(tasks), 16) # test epoch and offset - self.config.buffer.explorer_input.taskset.total_epochs = 2 - self.config.buffer.explorer_input.taskset.index = 4 + self.config.buffer.explorer_input.tasksets[0].total_epochs = 2 + self.config.buffer.explorer_input.tasksets[0].index = 4 reader = get_buffer_reader( - self.config.buffer.explorer_input.taskset, + self.config.buffer.explorer_input.tasksets[0], ) tasks = [] while True: @@ -41,9 +41,9 @@ def test_file_reader(self): # noqa: C901 self.assertEqual(len(tasks), 16 * 2 - 4) # test total steps and offset - self.config.buffer.explorer_input.taskset.total_steps = 5 - self.config.buffer.explorer_input.taskset.index = 8 - reader = get_buffer_reader(self.config.buffer.explorer_input.taskset) + self.config.buffer.explorer_input.tasksets[0].total_steps = 5 + self.config.buffer.explorer_input.tasksets[0].index = 8 + reader = get_buffer_reader(self.config.buffer.explorer_input.tasksets[0]) tasks = [] while True: try: @@ -53,10 +53,10 @@ def test_file_reader(self): # noqa: C901 self.assertEqual(len(tasks), 20 - 8) # test offset > dataset_len with total_epoch - self.config.buffer.explorer_input.taskset.total_steps = None - self.config.buffer.explorer_input.taskset.total_epochs = 3 - 
self.config.buffer.explorer_input.taskset.index = 20 - reader = get_buffer_reader(self.config.buffer.explorer_input.taskset) + self.config.buffer.explorer_input.tasksets[0].total_steps = None + self.config.buffer.explorer_input.tasksets[0].total_epochs = 3 + self.config.buffer.explorer_input.tasksets[0].index = 20 + reader = get_buffer_reader(self.config.buffer.explorer_input.tasksets[0]) tasks = [] while True: try: @@ -66,9 +66,9 @@ def test_file_reader(self): # noqa: C901 self.assertEqual(len(tasks), 16 * 3 - 20) # test offset > dataset_len with total_steps - self.config.buffer.explorer_input.taskset.total_steps = 10 - self.config.buffer.explorer_input.taskset.index = 24 - reader = get_buffer_reader(self.config.buffer.explorer_input.taskset) + self.config.buffer.explorer_input.tasksets[0].total_steps = 10 + self.config.buffer.explorer_input.tasksets[0].index = 24 + reader = get_buffer_reader(self.config.buffer.explorer_input.tasksets[0]) tasks = [] while True: try: From 11fc8f74c12f43089da4d762064a93d6ba30bf1a Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 24 Oct 2025 16:31:16 +0800 Subject: [PATCH 16/21] fix pre-commit --- trinity/common/verl_config.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 5d7a8247e6..ff017340df 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -413,9 +413,11 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 self.actor_rollout_ref.model.custom_chat_template = config.model.custom_chat_template self.actor_rollout_ref.actor.optim.total_training_steps = self.trainer.total_training_steps self.actor_rollout_ref.actor.ppo_mini_batch_size = config.buffer.train_batch_size - self.actor_rollout_ref.rollout.temperature = config.buffer.explorer_input.tasksets[ - 0 - ].rollout_args.temperature if config.buffer.explorer_input.tasksets else 1.0 + self.actor_rollout_ref.rollout.temperature = ( + config.buffer.explorer_input.tasksets[0].rollout_args.temperature + if config.buffer.explorer_input.tasksets + else 1.0 + ) self.actor_rollout_ref.rollout.n = config.algorithm.repeat_times if self.actor_rollout_ref.actor.grad_clip is None: self.actor_rollout_ref.actor.grad_clip = config.trainer.grad_clip From e2e9990a14993953d481ff41140229488d34c8cb Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 24 Oct 2025 16:51:49 +0800 Subject: [PATCH 17/21] fix config --- trinity/manager/synchronizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py index 415ebda945..94d11ae7f8 100644 --- a/trinity/manager/synchronizer.py +++ b/trinity/manager/synchronizer.py @@ -77,7 +77,7 @@ async def _check_modules(self) -> None: async def _find_latest_state_dict(self) -> None: assert self.config.trainer.trainer_type == "verl" - default_local_dir = self.config.trainer.trainer_config.trainer.default_local_dir + default_local_dir = self.config.checkpoint_job_dir local_latest_state_dict_iteration = os.path.join( default_local_dir, "latest_state_dict_iteration.txt" ) From 9b07610ad72b535ffa4c97069dec80df88392080 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 24 Oct 2025 17:32:28 +0800 Subject: [PATCH 18/21] fix doc --- docs/sphinx_doc/source/tutorial/develop_operator.md | 8 ++++---- docs/sphinx_doc/source/tutorial/develop_selector.md | 2 +- docs/sphinx_doc/source_zh/tutorial/develop_operator.md | 8 ++++---- docs/sphinx_doc/source_zh/tutorial/develop_selector.md | 2 +- 4 files changed, 10 
insertions(+), 10 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/develop_operator.md b/docs/sphinx_doc/source/tutorial/develop_operator.md index f05b1dc62d..394b3fb175 100644 --- a/docs/sphinx_doc/source/tutorial/develop_operator.md +++ b/docs/sphinx_doc/source/tutorial/develop_operator.md @@ -6,9 +6,9 @@ In Trinity-RFT, the operator module is responsible for processing experience data in the buffer module. It supports existing data processing capabilities from [Data-Juicer](https://github.com/modelscope/data-juicer) naturally, and allows developers to implement their own operators as well. By customizing operators, developers can implement various data processing functionalities, such as data augmentation, filtering, and transformation. You can even implement advantages/returns calculation as operators, as shown in {ref}`Algorithms ` section. -- **DataJuicerOperator** ({class}`trinity.data.operators.DataJuicerOperator`): The operator that wraps the data processing operators from Data-Juicer. It provides a simple interface for developers to list the Data-Juicer operators they want to use. The full list of Data-Juicer operators can be found [here](https://modelscope.github.io/data-juicer/en/main/docs/Operators.html). -- **ExperienceOperator** ({class}`trinity.data.operators.ExperienceOperator`): The base class for all operators used in experience data processing. It defines the interface and common functionalities that all operators should have. Each operator processes a batch of experience data and returns the processed data with metrics for logging. -- **ExperiencePipeline** ({class}`trinity.data.pipelines.ExperiencePipeline`): The experience data processing pipeline that manages a sequence of operators. It takes raw experiences from the `Explorer`, passes them through each operator in the pipeline, and writes the final processed experiences into the input buffer of the `Trainer`. +- **DataJuicerOperator** ({class}`trinity.buffer.operators.DataJuicerOperator`): The operator that wraps the data processing operators from Data-Juicer. It provides a simple interface for developers to list the Data-Juicer operators they want to use. The full list of Data-Juicer operators can be found [here](https://modelscope.github.io/data-juicer/en/main/docs/Operators.html). +- **ExperienceOperator** ({class}`trinity.buffer.operators.ExperienceOperator`): The base class for all operators used in experience data processing. It defines the interface and common functionalities that all operators should have. Each operator processes a batch of experience data and returns the processed data with metrics for logging. +- **ExperiencePipeline** ({class}`trinity.buffer.pipelines.ExperiencePipeline`): The experience data processing pipeline that manages a sequence of operators. It takes raw experiences from the `Explorer`, passes them through each operator in the pipeline, and writes the final processed experiences into the input buffer of the `Trainer`. ```{note} Except for `ExperiencePipeline`, Trinity-RFT also provides `TaskPipeline` for task data processing. @@ -56,7 +56,7 @@ class RewardFilter(ExperienceOperator): return filtered_exps, metrics ``` -After implementation, you need to register this module through {class}`trinity.data.operators.EXPERIENCE_OPERATORS`. Once registered, the module can be configured in the configuration file using the registered name. +After implementation, you need to register this module through {class}`trinity.buffer.operators.EXPERIENCE_OPERATORS`. 
Once registered, the module can be configured in the configuration file using the registered name. ### Step 2: Use Your Operator diff --git a/docs/sphinx_doc/source/tutorial/develop_selector.md b/docs/sphinx_doc/source/tutorial/develop_selector.md index d7593ce036..5e519df529 100644 --- a/docs/sphinx_doc/source/tutorial/develop_selector.md +++ b/docs/sphinx_doc/source/tutorial/develop_selector.md @@ -1,5 +1,5 @@ -# 🧪 Experimental: Task Selection & Scheduling System +# Selector Development Guide ```{note} This module is currently in **experimental status**. Interfaces may change in future versions. diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_operator.md b/docs/sphinx_doc/source_zh/tutorial/develop_operator.md index 692e6432f2..bb95b45f87 100644 --- a/docs/sphinx_doc/source_zh/tutorial/develop_operator.md +++ b/docs/sphinx_doc/source_zh/tutorial/develop_operator.md @@ -7,9 +7,9 @@ Operator 模块负责处理由 Explorer 所生成的轨迹数据(我们称之为 `Experience`)。它原生支持来自 [Data-Juicer](https://github.com/modelscope/data-juicer) 的数据处理功能,也允许开发者实现自己的算子。 通过自定义数据处理算子,开发者可以实现各种数据处理功能,如数据增强、过滤和转换。你甚至可以将优势值/回报值计算实现为 Operator,如 {ref}`算法 ` 部分所示。 -- **DataJuicerOperator** ({class}`trinity.data.operators.DataJuicerOperator`):封装后的 Data-Juicer 算子,使用时只需在配置文件中标明想要使用的 Data-Juicer 算子列表即可。完整的 Data-Juicer 算子列表请见 [此处](https://modelscope.github.io/data-juicer/en/main/docs/Operators.html)。 -- **ExperienceOperator** ({class}`trinity.data.operators.ExperienceOperator`):用于 experience 数据处理的所有数据处理算子的基类。定义了所有数据处理算子应具备的接口和通用功能。每个算子处理一批 experience 数据,并返回处理后的数据及用于日志记录的指标。 -- **ExperiencePipeline** ({class}`trinity.data.pipelines.ExperiencePipeline`):管理一系列数据处理算子的 experience 数据处理流水线。它从 `Explorer` 获取原始 experience,通过流水线中的每个算子处理,最后将最终处理过的 experience 写入 `Trainer` 的输入缓冲区。 +- **DataJuicerOperator** ({class}`trinity.buffer.operators.DataJuicerOperator`):封装后的 Data-Juicer 算子,使用时只需在配置文件中标明想要使用的 Data-Juicer 算子列表即可。完整的 Data-Juicer 算子列表请见 [此处](https://modelscope.github.io/data-juicer/en/main/docs/Operators.html)。 +- **ExperienceOperator** ({class}`trinity.buffer.operators.ExperienceOperator`):用于 experience 数据处理的所有数据处理算子的基类。定义了所有数据处理算子应具备的接口和通用功能。每个算子处理一批 experience 数据,并返回处理后的数据及用于日志记录的指标。 +- **ExperiencePipeline** ({class}`trinity.buffer.pipelines.ExperiencePipeline`):管理一系列数据处理算子的 experience 数据处理流水线。它从 `Explorer` 获取原始 experience,通过流水线中的每个算子处理,最后将最终处理过的 experience 写入 `Trainer` 的输入缓冲区。 ```{note} 除了 `ExperiencePipeline`,Trinity-RFT 还提供 `TaskPipeline` 用于任务数据处理。 @@ -57,7 +57,7 @@ class RewardFilter(ExperienceOperator): return filtered_exps, metrics ``` -实现后,你需要通过 {class}`trinity.data.operators.EXPERIENCE_OPERATORS` 注册此模块。注册后,该模块可在配置文件中使用注册名称进行配置。 +实现后,你需要通过 {class}`trinity.buffer.operators.EXPERIENCE_OPERATORS` 注册此模块。注册后,该模块可在配置文件中使用注册名称进行配置。 ### 步骤 2:使用此算子 diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_selector.md b/docs/sphinx_doc/source_zh/tutorial/develop_selector.md index 872e3819c4..ab565ab61f 100644 --- a/docs/sphinx_doc/source_zh/tutorial/develop_selector.md +++ b/docs/sphinx_doc/source_zh/tutorial/develop_selector.md @@ -1,4 +1,4 @@ -# 🧪 实验性功能:任务选择与调度系统 +# Selector 开发指南 ```{note} 该模块目前处于 **实验阶段**,接口可能在后续版本中发生变化。 From 6a878c7b767c4b50c339b4943c3365b8a7d2ab7e Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 24 Oct 2025 18:02:00 +0800 Subject: [PATCH 19/21] fix doc --- docs/sphinx_doc/source/tutorial/trinity_configs.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index ce177cd10d..a1c7c4cea4 100644 --- 
a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -277,8 +277,8 @@ The configuration for each task dataset is defined as follows: - `path`: The path to the task dataset. - For `file` storage type, the path points to the directory that contains the task dataset files. - For `sql` storage type, the path points to the sqlite database file. -- `subset_name`: The subset name of the task dataset, according to the `name` parameter in huggingface datasets `load_dataset` function. Default is `None`. -- `split`: The split of the task dataset, according to the `split` parameter in huggingface datasets `load_dataset` function. Default is `train`. +- `subset_name`: The subset name of the task dataset, corresponding to the `name` parameter in huggingface datasets `load_dataset` function. Default is `None`. +- `split`: The split of the task dataset, corresponding to the `split` parameter in huggingface datasets `load_dataset` function. Default is `train`. - `repeat_times`: The number of rollouts generated for a task. If not set, it will be automatically set to `algorithm.repeat_times` for `taskset`, and `1` for `eval_tasksets`. - `rollout_args`: The parameters for rollout. - `temperature`: The temperature for sampling. From c4f59df7845de533e03429833fce5f23b6b4f5d5 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 24 Oct 2025 18:10:14 +0800 Subject: [PATCH 20/21] fix doc --- docs/sphinx_doc/source/tutorial/develop_selector.md | 3 +-- docs/sphinx_doc/source_zh/tutorial/develop_selector.md | 2 +- docs/sphinx_doc/source_zh/tutorial/trinity_configs.md | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/develop_selector.md b/docs/sphinx_doc/source/tutorial/develop_selector.md index 5e519df529..c84f8f4267 100644 --- a/docs/sphinx_doc/source/tutorial/develop_selector.md +++ b/docs/sphinx_doc/source/tutorial/develop_selector.md @@ -1,5 +1,4 @@ - -# Selector Development Guide +# 🧪 Experimental: Task Selection & Scheduling System ```{note} This module is currently in **experimental status**. Interfaces may change in future versions. 
diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_selector.md b/docs/sphinx_doc/source_zh/tutorial/develop_selector.md index ab565ab61f..872e3819c4 100644 --- a/docs/sphinx_doc/source_zh/tutorial/develop_selector.md +++ b/docs/sphinx_doc/source_zh/tutorial/develop_selector.md @@ -1,4 +1,4 @@ -# Selector 开发指南 +# 🧪 实验性功能:任务选择与调度系统 ```{note} 该模块目前处于 **实验阶段**,接口可能在后续版本中发生变化。 diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index 6398417368..1866432ff0 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -335,7 +335,7 @@ buffer: - `chat_template`: 以字符串形式指定 chat template。若未提供,则使用 `model.custom_chat_template`。 - `max_read_timeout`: 读取新 experience 数据的最大等待时间(秒)。若超时,则直接返回不完整批次。仅当 `storage_type` 为 `queue` 时生效。默认为 1800 秒(30 分钟)。 - `replay_buffer`: 仅当 `storage_type` 为 `queue` 时生效。用于配置 experience 重用的回放缓冲区。 - - `enable`: 是否启用回放缓冲区。默认为 `false`。 + - `enable`: 是否将 experience 放回缓冲区。默认为 `false`。 - `reuse_cooldown_time`: experience 重用的冷却时间(秒)。若未指定,默认为 `None`,表示 experience 不可被重复使用。 - `priority_fn`: experience 优先级函数,用于确定 experience 的重用顺序。目前支持 `linear_decay` 和 `linear_decay_use_count_control_randomization`。 - `priority_fn_args`: 传递给优先级函数的参数字典,具体参数取决于所选的优先级函数。 From 784f29663a17a8aa8c6b92ac5f3afec904e9b156 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 27 Oct 2025 10:20:04 +0800 Subject: [PATCH 21/21] remove dup --- trinity/common/config.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/trinity/common/config.py b/trinity/common/config.py index a466af29ad..19889477e8 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -870,10 +870,6 @@ def _check_explorer_input(self) -> None: set_if_none(dataset, "ray_namespace", self.ray_namespace) set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens) - if self.mode != "train": - taskset.total_epochs = self.buffer.total_epochs - taskset.total_steps = self.buffer.total_steps - def _check_trainer_input(self) -> None: trainer_input = self.buffer.trainer_input experience_buffer = trainer_input.experience_buffer
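
Note on the net effect of this series: the buffer reader/writer factory functions now accept a single storage config that carries its own `batch_size`, instead of the earlier `(storage_config, buffer_config)` pair. The snippet below is a minimal sketch of the resulting API, not code taken from the series itself; it assumes the import paths implied by the file layout in the diffs above, and the sqlite path is hypothetical.

```python
# Minimal sketch of the simplified reader API after this series.
# Assumptions: import paths follow the repository layout shown in the
# diffs above; the sqlite path below is hypothetical.
import copy

from trinity.buffer.buffer import get_buffer_reader
from trinity.common.config import StorageConfig
from trinity.common.constants import StorageType

# A storage config now carries its own batch size, so no separate
# BufferConfig needs to be threaded through to the reader.
experience_buffer = StorageConfig(
    name="experience_buffer",
    schema_type="experience",
    path="sqlite:///experience.db",  # hypothetical path
    storage_type=StorageType.SQL,
    batch_size=32,
)
reader = get_buffer_reader(experience_buffer)

# To read the same storage with a different batch size, as
# MixSampleStrategy does for its expert buffer above, copy the config
# and override batch_size before constructing a second reader.
expert_config = copy.deepcopy(experience_buffer)
expert_config.batch_size = 8
expert_reader = get_buffer_reader(expert_config)
```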