diff --git a/tests/buffer/experience_storage_test.py b/tests/buffer/experience_storage_test.py index b3d225e122..8ba591e13f 100644 --- a/tests/buffer/experience_storage_test.py +++ b/tests/buffer/experience_storage_test.py @@ -5,13 +5,14 @@ import time import torch +from parameterized import parameterized from tests.tools import RayUnittestBaseAysnc from trinity.buffer.reader.sql_reader import SQLReader from trinity.buffer.writer.sql_writer import SQLWriter from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import StorageType -from trinity.common.experience import Experience +from trinity.common.experience import EID, Experience DB_PATH = os.path.join(os.path.dirname(__file__), "test.db") @@ -28,10 +29,11 @@ def setUp(self): if os.path.exists(DB_PATH): os.remove(DB_PATH) - async def test_sql_storage(self): + @parameterized.expand([("sft",), ("dpo",)]) + async def test_sql_storage(self, schema_type): meta = StorageConfig( name="test_storage", - schema_type="experience", + schema_type=schema_type, storage_type=StorageType.SQL, max_read_timeout=3, path=f"sqlite:///{DB_PATH}", @@ -49,8 +51,6 @@ async def test_sql_storage(self): ) for i in range(1, self.put_batch_size + 1) ] - for exp in exps: - exp.info = {"model_version": 0, "use_count": 0} for _ in range(self.total_num // self.put_batch_size): await writer.write_async(exps) for _ in range(self.total_num // self.train_batch_size): @@ -88,3 +88,49 @@ def thread_read(reader, result_queue): value = cursor.execute("SELECT COUNT(*) FROM test_storage;").fetchall() self.assertEqual(value[0][0], self.total_num + self.put_batch_size * 2) self.assertRaises(StopIteration, reader.read, batch_size=1) + + async def test_sql_experience_buffer(self): + meta = StorageConfig( + name="test_storage", + schema_type="experience", + storage_type=StorageType.SQL, + max_read_timeout=3, + path=f"sqlite:///{DB_PATH}", + ) + writer = SQLWriter(meta, self.config) + reader = SQLReader(meta, self.config) + 
self.assertEqual(await writer.acquire(), 1) + for idx in range(self.total_num // self.put_batch_size): + exps = [ + Experience( + eid=EID(task=idx * self.put_batch_size + i), + tokens=torch.tensor([float(j) for j in range(i + 1)]), + prompt_length=i, + reward=float(i), + logprobs=torch.tensor([0.1]), + ) + for i in range(1, self.put_batch_size + 1) + ] + await writer.write_async(exps) + cnt = self.total_num + for _ in range(self.total_num // self.train_batch_size): + exps = reader.read() + self.assertEqual(len(exps), self.train_batch_size) + for exp in exps: + self.assertEqual(exp.eid.task, cnt) + cnt -= 1 + + # experience buffer support experience reuse + cnt = self.total_num + for _ in range(self.total_num // self.train_batch_size): + exps = reader.read() + self.assertEqual(len(exps), self.train_batch_size) + for exp in exps: + self.assertEqual(exp.eid.task, cnt) + cnt -= 1 + self.assertEqual(await writer.release(), 0) + + conn = sqlite3.connect(DB_PATH) + cursor = conn.cursor() + value = cursor.execute("SELECT COUNT(*) FROM test_storage;").fetchall() + self.assertEqual(value[0][0], self.total_num) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 89bf738788..7710a3e593 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -30,7 +30,7 @@ SyncStyle, ) from trinity.common.models.utils import get_checkpoint_dir_with_step_num -from trinity.manager.manager import CacheManager +from trinity.manager.state_manager import StateManager class BaseTrainerCase(RayUnittestBase): @@ -266,7 +266,7 @@ def test_trainer(self): self.config.buffer.trainer_input.experience_buffer = StorageConfig( name="test_sql_storage", max_read_timeout=20, - storage_type=StorageType.SQL, + storage_type=StorageType.QUEUE, max_retry_times=10, ) self.config.check_and_update() @@ -516,10 +516,10 @@ def test_fully_async_mode(self): rollout_metrics = parser.metric_list("rollout") self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) # 
check the checkpoint - explorer1_cache = CacheManager(explorer1_config) + explorer1_cache = StateManager(explorer1_config) cache = explorer1_cache.load_explorer() self.assertEqual(cache["latest_iteration"], 4) - explorer2_cache = CacheManager(explorer2_config) + explorer2_cache = StateManager(explorer2_config) cache = explorer2_cache.load_explorer() self.assertEqual(cache["latest_iteration"], 4) # check the lastest checkpoint diff --git a/trinity/buffer/pipelines/experience_pipeline.py b/trinity/buffer/pipelines/experience_pipeline.py index efc1f16ad9..00a2ecb4b0 100644 --- a/trinity/buffer/pipelines/experience_pipeline.py +++ b/trinity/buffer/pipelines/experience_pipeline.py @@ -135,6 +135,9 @@ async def process(self, exps: List[Experience]) -> Dict: return result_metrics async def close(self) -> None: - await self.output.release() + try: + await self.output.release() + except Exception as e: + self.logger.error(f"Failed to release output buffer: {e}", exc_info=True) for operator in self.operators: operator.close() diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index 74a9135b7f..e0df0b1e8e 100644 --- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -45,6 +45,7 @@ class ExperienceModel(Base): # type: ignore reward = Column(Float, nullable=True) # serialized experience object experience_bytes = Column(LargeBinary, nullable=True) + consumed = Column(Integer, default=0, index=True) def to_experience(self) -> Experience: """Load the experience from the database.""" diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py index cbbdce4c53..4cbe083220 100644 --- a/trinity/buffer/storage/sql.py +++ b/trinity/buffer/storage/sql.py @@ -6,7 +6,7 @@ import ray from datasets import Dataset -from sqlalchemy import asc +from sqlalchemy import asc, desc from sqlalchemy.orm import sessionmaker from trinity.buffer.schema import init_engine @@ -88,29 +88,33 @@ def release(self) -> int: class 
SQLExperienceStorage(SQLStorage): + """Used as trainer input.""" + def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: super().__init__(storage_config, config) self.batch_size = config.train_batch_size self.max_timeout = storage_config.max_read_timeout + # TODO: optimize the following logic + if storage_config.schema_type == "experience": + # NOTE: consistent with the old version of experience buffer + self._read_method = self._read_priority + else: + # SFT / DPO uses FIFO style + self._read_method = self._read_fifo def write(self, data: List[Experience]) -> None: with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session: experience_models = [self.table_model_cls.from_experience(exp) for exp in data] session.add_all(experience_models) + self.logger.info(f"Write {len(experience_models)} experiences to SQL storage.") - def read(self, batch_size: Optional[int] = None) -> List[Experience]: - if self.stopped: - raise StopIteration() - + def _read_fifo(self, batch_size: int) -> List[Experience]: + """Read experiences in FIFO order.""" exp_list = [] - batch_size = batch_size or self.batch_size # type: ignore start_time = time.time() while len(exp_list) < batch_size: if self.stopped: raise StopIteration() - if len(exp_list): - self.logger.info(f"Waiting for {batch_size - len(exp_list)} more experiences...") - time.sleep(1) if time.time() - start_time > self.max_timeout: self.logger.warning( f"Max read timeout reached ({self.max_timeout} s), only get {len(exp_list)} experiences, stopping..." 
@@ -131,8 +135,61 @@ def read(self, batch_size: Optional[int] = None) -> List[Experience]: self.offset = experiences[-1].id start_time = time.time() exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences]) + if len(exp_list) < batch_size: + self.logger.info(f"Waiting for {batch_size - len(exp_list)} more experiences...") + time.sleep(1) return exp_list + def _read_priority(self, batch_size: int) -> List[Experience]: + exp_list = [] + start_time = time.time() + latest_size = 0 + while latest_size < batch_size: + if self.stopped: + raise StopIteration() + if time.time() - start_time > self.max_timeout: + self.logger.warning( + f"Max read timeout reached ({self.max_timeout} s), only get {latest_size} experiences, stopping..." + ) + raise StopIteration() + with retry_session( + self.session, self.max_retry_times, self.max_retry_interval + ) as session: + experiences = ( + session.query(self.table_model_cls) + .order_by(asc(self.table_model_cls.consumed), desc(self.table_model_cls.id)) + .limit(batch_size) + .with_for_update() + .all() + ) + if len(experiences) != batch_size: + if latest_size != len(experiences): + latest_size = len(experiences) + start_time = time.time() + else: + ids = [exp.id for exp in experiences] + session.query(self.table_model_cls).filter( + self.table_model_cls.id.in_(ids) + ).update( + {self.table_model_cls.consumed: self.table_model_cls.consumed + 1}, + synchronize_session=False, + ) + exp_list.extend( + [self.table_model_cls.to_experience(exp) for exp in experiences] + ) + break + + self.logger.info(f"Waiting for {batch_size - latest_size} more experiences...") + time.sleep(1) + return exp_list + + def read(self, batch_size: Optional[int] = None) -> List[Experience]: + if self.stopped: + raise StopIteration() + + batch_size = batch_size or self.batch_size + return self._read_method(batch_size) + @classmethod def load_from_dataset( cls, dataset: Dataset, storage_config: StorageConfig, config: BufferConfig @@ -158,6 
+215,8 @@ def load_from_dataset( class SQLTaskStorage(SQLStorage): + """Used as explorer input.""" + def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: super().__init__(storage_config, config) self.batch_size = config.batch_size diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py index 9ffaa13adb..ebe4ed0267 100644 --- a/trinity/buffer/writer/sql_writer.py +++ b/trinity/buffer/writer/sql_writer.py @@ -25,19 +25,19 @@ def write(self, data: list) -> None: async def write_async(self, data): if self.wrap_in_ray: - await self.db_wrapper.write.remote(data) + ray.get(self.db_wrapper.write.remote(data)) else: self.db_wrapper.write(data) async def acquire(self) -> int: if self.wrap_in_ray: - return await self.db_wrapper.acquire.remote() + return ray.get(self.db_wrapper.acquire.remote()) else: return 0 async def release(self) -> int: if self.wrap_in_ray: - return await self.db_wrapper.release.remote() + return ray.get(self.db_wrapper.release.remote()) else: self.db_wrapper.release() return 0 diff --git a/trinity/common/config.py b/trinity/common/config.py index 2e84530eea..b11a6081dd 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -81,6 +81,9 @@ class StorageConfig: path: Optional[str] = None repeat_times: Optional[int] = None + # For continuing training + index: int = 0 + # used for multi-modal data mm_data_kwargs: dict = field(default_factory=dict) @@ -88,7 +91,6 @@ class StorageConfig: split: str = "train" subset_name: Optional[str] = None format: FormatConfig = field(default_factory=FormatConfig) - index: int = 0 # used for StorageType.QUEUE capacity: int = 10000 diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index b78856442e..242076648c 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -25,7 +25,7 @@ ) from trinity.common.models import create_inference_models from trinity.explorer.scheduler import Scheduler -from 
trinity.manager.manager import CacheManager +from trinity.manager.state_manager import StateManager from trinity.manager.synchronizer import Synchronizer from trinity.utils.log import get_logger from trinity.utils.monitor import MONITOR, gather_metrics @@ -38,16 +38,16 @@ class Explorer: def __init__(self, config: Config): self.logger = get_logger(config.explorer.name, in_ray_actor=True) load_plugins() - self.cache = CacheManager(config) - explorer_meta = self.cache.load_explorer() - self.explore_step_num = explorer_meta.get("latest_iteration", 0) + self.state = StateManager(config) + explorer_state = self.state.load_explorer() + self.explore_step_num = explorer_state.get("latest_iteration", 0) self.last_sync_step = self.explore_step_num if self.explore_step_num > 0 else -1 self.synchronizer = Synchronizer.get_actor(config) self.config = config self.algorithm_manager = AlgorithmManager(config) self.models, self.auxiliary_models = create_inference_models(config) self.experience_pipeline = self._init_experience_pipeline() - self.config.buffer.explorer_input.taskset.index = explorer_meta.get("latest_task_index", 0) + self.config.buffer.explorer_input.taskset.index = explorer_state.get("latest_task_index", 0) self.taskset = get_buffer_reader( self.config.buffer.explorer_input.taskset, self.config.buffer ) @@ -326,7 +326,7 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None: ) # save explore checkpoint - self.cache.save_explorer( + self.state.save_explorer( current_step=self.explore_step_num, current_task_index=self.explore_step_num * self.config.buffer.batch_size, ) @@ -345,7 +345,6 @@ async def _finish_steps(self, start_step: int, end_step: int, model_version: int async def _finish_explore_step(self, step: int, model_version: int) -> None: statuses, exps = await self.scheduler.get_results(batch_id=step) metric = {"rollout/model_version": model_version} - # TODO: avoid blocking pipeline_metrics = await self.experience_pipeline.process.remote(exps) 
metric.update(pipeline_metrics) if statuses: diff --git a/trinity/manager/__init__.py b/trinity/manager/__init__.py index 663b1ae799..5b81d2d345 100644 --- a/trinity/manager/__init__.py +++ b/trinity/manager/__init__.py @@ -1,7 +1,7 @@ -from trinity.manager.manager import CacheManager +from trinity.manager.state_manager import StateManager from trinity.manager.synchronizer import Synchronizer __all__ = [ - "CacheManager", + "StateManager", "Synchronizer", ] diff --git a/trinity/manager/manager.py b/trinity/manager/state_manager.py similarity index 68% rename from trinity/manager/manager.py rename to trinity/manager/state_manager.py index f65f5af204..a7722aafa8 100644 --- a/trinity/manager/manager.py +++ b/trinity/manager/state_manager.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -"""Data manager.""" +"""State manager.""" import json import os @@ -7,14 +7,14 @@ from trinity.utils.log import get_logger -class CacheManager: - """A Manager class for managing the cache dir.""" +class StateManager: + """A Manager class for managing the running state of Explorer and Trainer.""" def __init__(self, config: Config, check_config: bool = False): self.logger = get_logger(__name__, in_ray_actor=True) self.cache_dir = config.monitor.cache_dir # type: ignore - self.explorer_meta_path = os.path.join(self.cache_dir, f"{config.explorer.name}_meta.json") # type: ignore - self.trainer_meta_path = os.path.join(self.cache_dir, f"{config.trainer.name}_meta.json") # type: ignore + self.explorer_state_path = os.path.join(self.cache_dir, f"{config.explorer.name}_meta.json") # type: ignore + self.trainer_state_path = os.path.join(self.cache_dir, f"{config.trainer.name}_meta.json") # type: ignore if check_config: self._check_config_consistency(config) @@ -34,7 +34,7 @@ def _check_config_consistency(self, config: Config) -> None: ) def save_explorer(self, current_task_index: int, current_step: int) -> None: - with open(self.explorer_meta_path, "w", encoding="utf-8") as f: + with 
open(self.explorer_state_path, "w", encoding="utf-8") as f: json.dump( {"latest_task_index": current_task_index, "latest_iteration": current_step}, f, @@ -42,9 +42,9 @@ def save_explorer(self, current_task_index: int, current_step: int) -> None: ) def load_explorer(self) -> dict: - if os.path.exists(self.explorer_meta_path): + if os.path.exists(self.explorer_state_path): try: - with open(self.explorer_meta_path, "r", encoding="utf-8") as f: + with open(self.explorer_state_path, "r", encoding="utf-8") as f: explorer_meta = json.load(f) self.logger.info( "----------------------------------\n" @@ -55,17 +55,21 @@ def load_explorer(self) -> dict: ) return explorer_meta except Exception as e: - self.logger.error(f"Failed to load explore meta file: {e}") + self.logger.error(f"Failed to load explore state file: {e}") return {} - def save_trainer(self, current_step: int) -> None: - with open(self.trainer_meta_path, "w", encoding="utf-8") as f: - json.dump({"latest_iteration": current_step}, f, indent=2) + def save_trainer(self, current_exp_index: int, current_step: int) -> None: + with open(self.trainer_state_path, "w", encoding="utf-8") as f: + json.dump( + {"latest_exp_index": current_exp_index, "latest_iteration": current_step}, + f, + indent=2, + ) def load_trainer(self) -> dict: - if os.path.exists(self.trainer_meta_path): + if os.path.exists(self.trainer_state_path): try: - with open(self.trainer_meta_path, "r", encoding="utf-8") as f: + with open(self.trainer_state_path, "r", encoding="utf-8") as f: trainer_meta = json.load(f) self.logger.info( "----------------------------------\n" @@ -76,5 +80,5 @@ def load_trainer(self) -> dict: ) return trainer_meta except Exception as e: - self.logger.warning(f"Failed to load trainer meta file: {e}") + self.logger.warning(f"Failed to load trainer state file: {e}") return {} diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 9626818525..680d9594e9 100644 --- a/trinity/trainer/trainer.py +++ 
b/trinity/trainer/trainer.py @@ -17,6 +17,7 @@ from trinity.common.config import Config from trinity.common.constants import RunningStatus, SyncMethod, SyncStyle from trinity.common.experience import Experiences +from trinity.manager.state_manager import StateManager from trinity.manager.synchronizer import Synchronizer from trinity.utils.log import get_logger from trinity.utils.monitor import MONITOR @@ -32,6 +33,11 @@ def __init__(self, config: Config) -> None: load_plugins() self.synchronizer = Synchronizer.get_actor(config) self.engine = get_trainer_wrapper(config) + self.state = StateManager(config) + trainer_state = self.state.load_trainer() + config.buffer.trainer_input.experience_buffer.index = trainer_state.get( + "latest_exp_index", 0 + ) self.last_trainer_sync_step = 0 self.monitor = MONITOR.get(config.monitor.monitor_type)( project=config.project, @@ -70,7 +76,7 @@ async def train(self) -> str: self.logger.error(f"Error in Trainer:\n{traceback.format_exc()}") self.train_continue = False - self.engine.save_checkpoint(block_until_saved=True) + self.save_checkpoint(block_until_saved=True) await self.synchronizer.set_trainer_status.remote(RunningStatus.STOPPED) self.logger.info("--------------------\n> Trainer finished.\n--------------------") return self.config.trainer.name @@ -93,7 +99,7 @@ async def train_step(self) -> bool: or self.train_step_num % self.config.trainer.save_interval != 0 ): self.logger.info(f"Saving at step {self.train_step_num}.") - self.engine.save_checkpoint() + self.save_checkpoint() self.logger.info(f"Saved at step {self.train_step_num}.") return False self.logger.info(f"Sampling at step {self.train_step_num + 1} done.") @@ -151,6 +157,13 @@ def _log_experiences(self, samples: List[Dict]) -> None: ) self._sample_exps_to_log.clear() + def save_checkpoint(self, block_until_saved: bool = False) -> None: + self.engine.save_checkpoint(block_until_saved=block_until_saved) + self.state.save_trainer( + 
current_exp_index=self.engine.train_step_num * self.config.buffer.train_batch_size, + current_step=self.train_step_num, + ) + async def shutdown(self) -> None: self.monitor.close() diff --git a/trinity/trainer/verl/megatron_workers.py b/trinity/trainer/verl/megatron_workers.py index e4431bfe28..a03f5d99ee 100644 --- a/trinity/trainer/verl/megatron_workers.py +++ b/trinity/trainer/verl/megatron_workers.py @@ -87,11 +87,7 @@ def __init__(self, config: DictConfig, role: str, **kwargs): rank = int(os.environ["LOCAL_RANK"]) torch.distributed.init_process_group( backend=get_nccl_backend(), - timeout=datetime.timedelta( - seconds=self.config.get( - "nccl_timeout", seconds=self.config.synchronizer.sync_timeout - ) - ), + timeout=datetime.timedelta(seconds=self.config.synchronizer.sync_timeout), init_method=os.environ.get("DIST_INIT_METHOD", None), ) get_torch_device().set_device(rank) @@ -696,7 +692,7 @@ def __init__(self, config): rank = int(os.environ["LOCAL_RANK"]) torch.distributed.init_process_group( backend=get_nccl_backend(), - timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), + timeout=datetime.timedelta(seconds=self.config.synchronizer.sync_timeout), init_method=os.environ.get("DIST_INIT_METHOD", None), ) get_torch_device().set_device(rank)