diff --git a/tests/common/synchronizer_test.py b/tests/common/synchronizer_test.py index 538f974526..37bfba3929 100644 --- a/tests/common/synchronizer_test.py +++ b/tests/common/synchronizer_test.py @@ -34,10 +34,11 @@ def trainer_monkey_patch(config: Config, max_steps: int, intervals: List[int]): - def new_train_step(self): + async def new_train_step(self): self.engine.algorithm = ALGORITHM_TYPE.get(config.algorithm.algorithm_type) self.engine.global_steps += 1 self.logger.info(f"Training at step {self.engine.global_steps} started.") + await asyncio.sleep(0.1) time.sleep(intervals[self.engine.global_steps - 1]) metrics = {"actor/step": self.engine.global_steps} self.monitor.log(data=metrics, step=self.engine.global_steps) diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index b45560a269..9508fd1dd8 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -44,16 +44,16 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config ) - def sample(self, step: int) -> Tuple[Experiences, Dict, List]: + async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: metrics = {} with Timer(metrics, "read_time"): - usual_exp_list = self.usual_exp_buffer.read() + usual_exp_list = await self.usual_exp_buffer.read_async() for exp in usual_exp_list: if exp.info is None: exp.info = {} exp.info["is_expert"] = False - expert_exp_list = self.expert_exp_buffer.read() + expert_exp_list = await self.expert_exp_buffer.read_async() for exp in expert_exp_list: exp.reward = 0.0 exp.logprobs = torch.zeros_like( diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index b6b3c1e356..4b591ac3c7 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -16,7 +16,7 @@ def __init__(self, buffer_config: BufferConfig, **kwargs) -> None: self.pad_token_id = buffer_config.pad_token_id @abstractmethod - def sample(self, step: int) -> Tuple[Experiences, Dict, List]: + async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: """Sample data from buffer. Args: @@ -53,13 +53,13 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): else: self.sft_buffer = None - def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]: + async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]: metrics = {} with Timer(metrics, "read_time"): if step <= self.sft_warmup_steps: - exp_list = self.sft_buffer.read() + exp_list = await self.sft_buffer.read_async() else: - exp_list = self.exp_buffer.read() + exp_list = await self.exp_buffer.read_async() repr_samples = representative_sample(exp_list) with Timer(metrics, "gather_time"): exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore @@ -78,10 +78,10 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore ) - def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: + async def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: metrics = {} with Timer(metrics, "read_time"): - exp_list = self.exp_buffer.read() + exp_list = await self.exp_buffer.read_async() repr_samples = representative_sample(exp_list) with Timer(metrics, "gather_time"): exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore diff --git a/trinity/buffer/buffer_reader.py b/trinity/buffer/buffer_reader.py index e5894b7521..7759735ab1 100644 --- a/trinity/buffer/buffer_reader.py +++ b/trinity/buffer/buffer_reader.py @@ -13,3 +13,9 @@ def read( self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None ) -> List: """Read from buffer.""" + + @abstractmethod + async def read_async( + self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None + ) -> List: + """Read from buffer asynchronously.""" diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index a30f58ee8c..ce0d645f22 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -182,6 +182,11 @@ def read( raise ValueError(f"Unknown data format: {self.prompt_type}") return exp_list + async def read_async( + self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None + ): + return self.read(batch_size, strategy) + @FILE_READERS.register_module(DPOAlgorithm.name()) class DPODataReader(BufferReader): @@ -259,6 +264,11 @@ def read( exp_list.append(experience) return exp_list + async def read_async( + self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None + ): + return self.read(batch_size, strategy) + @FILE_READERS.register_module("rollout") class RolloutDataReader(BufferReader): @@ -323,6 +333,11 @@ def read( tasks.append(task) return tasks + async def read_async( + self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None + ): + return self.read(batch_size, strategy) + @FILE_READERS.register_module("raw") class RawDataReader(BufferReader): @@ -340,3 +355,8 @@ def read( raise StopIteration self.returned = True return self.dataset.to_list() + + async def read_async( + self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None + ): + return self.read(batch_size, strategy) diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index 4e59b9e297..0968f0940a 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -37,3 +37,19 @@ def read( except StopAsyncIteration: raise StopIteration() return exps + + async def read_async( + self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None + ) -> List: + if strategy is not None and strategy != ReadStrategy.FIFO: + raise NotImplementedError(f"Read strategy {strategy} not supported for Queue Reader.") + try: + batch_size = batch_size or self.read_batch_size + exps = await self.queue.get_batch.remote(batch_size, timeout=self.timeout) + if len(exps) != batch_size: + raise TimeoutError( + f"Read incomplete batch ({len(exps)}/{batch_size}), please check your workflow." + ) + except StopAsyncIteration: + raise StopIteration() + return exps diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py index cc725f842c..0dee66218a 100644 --- a/trinity/buffer/reader/sql_reader.py +++ b/trinity/buffer/reader/sql_reader.py @@ -25,3 +25,11 @@ def read( return ray.get(self.db_wrapper.read.remote(batch_size, strategy)) else: return self.db_wrapper.read(batch_size, strategy) + + async def read_async( + self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None + ) -> List: + if self.wrap_in_ray: + return await self.db_wrapper.read.remote(batch_size, strategy) + else: + return self.db_wrapper.read(batch_size, strategy) diff --git a/trinity/common/synchronizer.py b/trinity/common/synchronizer.py index 3e2070effa..17e9ae307b 100644 --- a/trinity/common/synchronizer.py +++ b/trinity/common/synchronizer.py @@ -184,7 +184,7 @@ async def wait_new_model_state_dict(self, current_version: int, no_wait: bool = ) if self.model_version > current_version: self.set_explorer_status( - RunningStatus.WAITING_SYNC, old_status=RunningStatus.REQUIRE_SYNC + RunningStatus.RUNNING, old_status=RunningStatus.REQUIRE_SYNC ) return self.model_version @@ -237,10 +237,17 @@ def sync_failed(): await asyncio.wait_for( self._ready_condition.wait_for( lambda: self.explorer_status_counts[RunningStatus.WAITING_SYNC] + + self.explorer_status_counts[RunningStatus.STOPPED] == 1, ), timeout=self.config.synchronizer.sync_timeout, ) + if self.explorer_status_counts[RunningStatus.STOPPED] == 1: + return sync_failed() + self.set_explorer_status( + RunningStatus.RUNNING, + old_status=RunningStatus.WAITING_SYNC, + ) elif module == "explorer": self.set_explorer_status( RunningStatus.WAITING_SYNC, old_status=RunningStatus.REQUIRE_SYNC @@ -274,13 +281,19 @@ def get_actor(cls, config: Optional[Config] = None, namespace: Optional[str] = N A reference to the Synchronizer actor. """ if config is not None: + if config.mode == "explore" or ( + config.mode == "train" and config.algorithm.algorithm_type not in {"dpo", "sft"} + ): + lifetime = "detached" + else: + lifetime = None return ( ray.remote(cls) .options( name="synchronizer", namespace=config.ray_namespace, get_if_exists=True, - lifetime="detached", + lifetime=lifetime, ) .remote(config) ) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 4f6428f1ba..6dc88e7fce 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -140,9 +140,6 @@ async def _pull_latest_weights(self): ) self.model_version = new_version self.last_sync_step = self.explore_step_num - await self.synchronizer.set_explorer_status.remote( - RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC - ) self.last_sync_successful = True else: self.logger.warning( @@ -163,9 +160,6 @@ async def _nccl_weights_update(self): *[model.sync_model.remote(self.model_version) for model in self.models] ) self.last_sync_step = self.explore_step_num - await self.synchronizer.set_explorer_status.remote( - RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC - ) self.last_sync_successful = True async def prepare(self) -> None: @@ -183,7 +177,7 @@ async def prepare(self) -> None: ) await asyncio.gather(*futures, return_exceptions=True) if self.config.explorer.eval_on_startup and self.explore_step_num == 0: - self.eval() + await self.eval() await self.synchronizer.set_explorer_status.remote(RunningStatus.REQUIRE_SYNC) async def get_weight(self, name: str) -> torch.Tensor: @@ -210,7 +204,7 @@ async def explore(self) -> str: # TODO: support eval on last checkpoint break if self.need_eval(): - self.eval() + await self.eval() if await self.need_sync(): await self.sync_weight() except Exception: @@ -228,8 +222,10 @@ async def explore_step(self) -> bool: self.explore_step_num += 1 return True try: - tasks = self.taskset.read() - except StopIteration: + tasks = await self.taskset.read_async() + except (StopIteration, RuntimeError) as e: + if isinstance(e, RuntimeError) and "StopIteration" not in str(e): + raise self.logger.warning("No more tasks to explore. Stop exploring.") await self.save_checkpoint(sync_weight=False) await self.synchronizer.set_explorer_status.remote( @@ -270,7 +266,7 @@ async def need_sync(self) -> bool: def need_eval(self) -> bool: return self.explore_step_num % self.config.explorer.eval_interval == 0 - def eval(self): + async def eval(self): """Evaluation on all evaluation data samples.""" if len(self.config.buffer.explorer_input.eval_tasksets) == 0: self.logger.warning("No evaluation data samples. Skip evaluation.") @@ -285,8 +281,11 @@ def eval(self): self.pending_eval_tasks.append((self.explore_step_num, eval_taskset.name)) while True: try: - self.scheduler.schedule(eval_taskset.read(), batch_id=eval_batch_id) - except StopIteration: + data = await eval_taskset.read_async() + self.scheduler.schedule(data, batch_id=eval_batch_id) + except (StopIteration, RuntimeError) as e: + if isinstance(e, RuntimeError) and "StopIteration" not in str(e): + raise break async def benchmark(self) -> bool: @@ -294,7 +293,7 @@ async def benchmark(self) -> bool: # benchmark on the latest checkpoint if self.config.explorer.bench_on_latest_checkpoint: self.explore_step_num = await self._checkpoint_weights_update() - self.eval() + await self.eval() await self._finish_eval_step(prefix="bench") return True @@ -313,7 +312,7 @@ async def benchmark(self) -> bool: ) for step_num in all_ckp_steps: self.explore_step_num = await self._checkpoint_weights_update(step_num=step_num) - self.eval() + await self.eval() await self._finish_eval_step(prefix="bench") return True @@ -391,8 +390,8 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva self.monitor.log(metric, step) async def shutdown(self) -> None: + await self.scheduler.stop() self.monitor.close() if await self.synchronizer.release.remote() == 0: ray.kill(self.synchronizer) self.logger.info("Synchronizer stopped.") - await self.scheduler.stop() diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 00de93c10d..0bc6c076f1 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -4,6 +4,7 @@ """ from __future__ import annotations +import asyncio import traceback from abc import ABC, abstractmethod from typing import Dict, List, Tuple @@ -43,41 +44,50 @@ def __init__(self, config: Config) -> None: buffer_config=config.buffer, **config.algorithm.sample_strategy_args, ) + self.train_continue = True + self.last_sync_step = None def prepare(self) -> None: """Prepare the trainer.""" self.engine.prepare() - self.last_trainer_sync_step = self.engine.train_step_num + self.last_trainer_sync_step = self.train_step_num ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)) - def train(self) -> str: + async def train(self) -> str: """Train the model.""" - while True: + while self.train_continue: try: - train_continue = self.train_step() - if not train_continue: - break - if self.need_sync(): + train_task = asyncio.create_task(self.train_step()) + while not train_task.done(): + if self.need_sync(): + self.sync_weight() + await asyncio.sleep(1) + self.train_continue &= await train_task + if self.train_continue and self.need_sync(): self.sync_weight() except Exception: self.logger.error(f"Error in Trainer:\n{traceback.format_exc()}") - break - ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.STOPPED)) + self.train_continue = False + self.engine.save_checkpoint(block_until_saved=True) + await self.synchronizer.set_trainer_status.remote(RunningStatus.STOPPED) self.logger.info("--------------------\n> Trainer finished.\n--------------------") return self.config.trainer.name - def train_step(self) -> bool: + async def train_step(self) -> bool: """Train one step. Returns: bool: Whether to continue training. """ + self.logger.info(f"Training at step {self.train_step_num + 1} started.") try: - batch, sample_metrics, repr_samples = self.sample_strategy.sample( + batch, sample_metrics, repr_samples = await self.sample_strategy.sample( self.train_step_num + 1 ) - except StopIteration: + except (StopIteration, RuntimeError) as e: + if isinstance(e, RuntimeError) and "StopIteration" not in str(e): + raise self.logger.info("No more samples to train. Stopping training.") if ( self.config.trainer.save_interval == 0 @@ -87,7 +97,9 @@ def train_step(self) -> bool: self.engine.save_checkpoint() self.logger.info(f"Saved at step {self.train_step_num}.") return False + self.logger.info(f"Sampling at step {self.train_step_num + 1} done.") continue_run, metrics = self.engine.train_step(batch) + self.logger.info(f"Training at step {self.train_step_num} finished.") prefix_metrics(sample_metrics, "sample", metrics) self.monitor.log(data=metrics, step=self.train_step_num) if self.config.trainer.enable_preview: @@ -97,10 +109,13 @@ def train_step(self) -> bool: def need_sync(self) -> bool: """Whether to sync the model weight.""" if self.config.synchronizer.sync_style == SyncStyle.FIXED: - return self.engine.train_step_num % self.config.synchronizer.sync_interval == 0 + return ( + self.last_sync_step != self.train_step_num + and self.train_step_num % self.config.synchronizer.sync_interval == 0 + ) else: if self.config.synchronizer.sync_style == SyncStyle.DYNAMIC_BY_TRAINER: - delta = self.engine.train_step_num - self.last_trainer_sync_step + delta = self.train_step_num - self.last_trainer_sync_step if delta >= self.config.synchronizer.sync_interval: ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.REQUIRE_SYNC)) explorer_status_counts = ray.get(self.synchronizer.get_explorer_status_counts.remote()) @@ -111,23 +126,22 @@ def need_sync(self) -> bool: def sync_weight(self) -> None: """Sync the model weight.""" - self.logger.info( - f"Trainer synchronizing weights at step {self.engine.train_step_num} starting.." - ) + self.logger.info(f"Trainer synchronizing weights at step {self.train_step_num} starting..") if self.config.synchronizer.sync_method == SyncMethod.NCCL: result = ray.get( - self.synchronizer.ready_to_nccl_sync.remote("trainer", self.engine.train_step_num) + self.synchronizer.ready_to_nccl_sync.remote("trainer", self.train_step_num) ) if result is None: self.logger.error("Trainer synchronizing weights failed.") else: self.engine.sync_weight() - self.last_trainer_sync_step = self.engine.train_step_num + self.last_trainer_sync_step = self.train_step_num elif self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT: self.engine.save_state_dict() elif self.config.synchronizer.sync_method == SyncMethod.MEMORY: self.engine.upload_state_dict() - self.logger.info(f"Trainer synchronizing weights at step {self.engine.train_step_num} end.") + self.logger.info(f"Trainer synchronizing weights at step {self.train_step_num} end.") + self.last_sync_step = self.train_step_num ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)) def _log_experiences(self, samples: List[Dict]) -> None: @@ -138,9 +152,9 @@ def _log_experiences(self, samples: List[Dict]) -> None: ) self._sample_exps_to_log.clear() - def shutdown(self) -> None: + async def shutdown(self) -> None: self.monitor.close() - if ray.get(self.synchronizer.release.remote()) == 0: + if await self.synchronizer.release.remote() == 0: ray.kill(self.synchronizer) self.logger.info("Synchronizer stopped.") diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index e85c8ef540..27edf701e4 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -593,7 +593,9 @@ def setup_weight_sync_group(self): master_address, master_port = self.get_availale_master_addr_port() world_size = self.config.synchronizer.explorer_world_size + 1 print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).") - synchronizer = Synchronizer.get_actor(self.config.synchronizer) + synchronizer = Synchronizer.get_actor( + namespace=self.config.synchronizer.ray_namespace + ) setup_ref = synchronizer.setup_weight_sync_group.remote( master_address, master_port, self.state_dict_meta ) @@ -837,8 +839,6 @@ def save_checkpoint( @register(dispatch_mode=Dispatch.ONE_TO_ALL) def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False): - print(f" {self._is_actor=} and {self._is_ref=}") - if self._is_actor and self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 2f50b18ca2..59453592d2 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -285,11 +285,9 @@ def upload_state_dict(self): # state dict sync self.actor_rollout_wg.upload_state_dict(self.global_steps) def train_step(self, batch: Experiences) -> Tuple[bool, Dict]: # noqa C901 - self.logger.info(f"Training at step {self.global_steps + 1} started.") batch = to_data_proto(batch) metrics = {} self.global_steps += 1 - self.logger.info(f"Sampling at step {self.global_steps} done.") timing_raw = {} algorithm_config = self.algorithm_manager.get_current_algorithm_config(self.global_steps) algorithm = ALGORITHM_TYPE.get(algorithm_config.algorithm_type) @@ -369,7 +367,6 @@ def train_step(self, batch: Experiences) -> Tuple[bool, Dict]: # noqa C901 with marked_timer("save_checkpoint", timing_raw): self.save_checkpoint() self.logger.info(f"Saved at step {self.global_steps}.") - self.logger.info(f"Training at step {self.global_steps} finished.") return train_status, metrics def save_checkpoint(self, block_until_saved: bool = False) -> None: