diff --git a/tests/common/synchronizer_test.py b/tests/common/synchronizer_test.py
index 37bfba3929..b341f1d982 100644
--- a/tests/common/synchronizer_test.py
+++ b/tests/common/synchronizer_test.py
@@ -105,6 +105,79 @@ def tearDown(self):
         shutil.rmtree(os.path.join(checkpoint_path, "unittest"))
 
 
+class TestSynchronizerExit(BaseTestSynchronizer):
+    def test_synchronizer(self):
+        config = get_template_config()
+        config.project = "unittest"
+        config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}"
+        config.checkpoint_root_dir = get_checkpoint_path()
+        config.buffer.total_epochs = 1
+        config.buffer.batch_size = 4
+        config.cluster.gpu_per_node = 2
+        config.cluster.node_num = 1
+        config.model.model_path = get_model_path()
+        config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
+        config.buffer.trainer_input.experience_buffer = StorageConfig(
+            name="exp_buffer",
+            storage_type=StorageType.QUEUE,
+            wrap_in_ray=True,
+        )
+        config.synchronizer.sync_method = SyncMethod.CHECKPOINT
+        config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER
+        config.synchronizer.sync_interval = 2
+        config.trainer.save_interval = 100
+        config.monitor.monitor_type = "tensorboard"
+        trainer_config = deepcopy(config)
+        trainer_config.mode = "train"
+        trainer_config.check_and_update()
+
+        explorer1_config = deepcopy(config)
+        explorer1_config.mode = "explore"
+        explorer1_config.explorer.name = "explorer1"
+        explorer1_config.explorer.rollout_model.engine_num = 1
+        explorer1_config.explorer.rollout_model.tensor_parallel_size = 1
+        explorer1_config.buffer.explorer_output = StorageConfig(
+            name="exp_buffer",
+            storage_type=StorageType.QUEUE,
+            wrap_in_ray=True,
+        )
+        explorer1_config.check_and_update()
+
+        trainer_process = multiprocessing.Process(
+            target=run_trainer, args=(trainer_config, 8, [2, 1, 2, 1, 2, 1, 2, 1])
+        )
+        trainer_process.start()
+        ray.init(ignore_reinit_error=True)
+        while True:
+            try:
+                synchronizer = ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace)
+                break
+            except ValueError:
+                print("waiting for trainer to start.")
+                time.sleep(5)
+        explorer_process_1 = multiprocessing.Process(
+            target=run_explorer,
+            args=(explorer1_config, 8, [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5]),
+        )
+        explorer_process_1.start()
+
+        self.assertEqual(
+            synchronizer, ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace)
+        )
+        time.sleep(5)
+        trainer_process.terminate()
+        trainer_process.join()
+        self.assertEqual(
+            synchronizer, ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace)
+        )
+
+        explorer_process_1.terminate()
+        explorer_process_1.join()
+        time.sleep(6)
+        with self.assertRaises(ValueError):
+            ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace)
+
+
 @parameterized_class(
     (
         "sync_method",
diff --git a/trinity/common/synchronizer.py b/trinity/common/synchronizer.py
index 17e9ae307b..1d175d773d 100644
--- a/trinity/common/synchronizer.py
+++ b/trinity/common/synchronizer.py
@@ -30,7 +30,7 @@ class Synchronizer:
         checkpoint_shard_counter: Tracks how many shards are received from trainer for a specific train step.
     """
 
-    def __init__(self, config: Config):
+    def __init__(self, config: Config, module_ref: ray.actor.ActorHandle):
         self.logger = get_logger(__name__)
         self.config = config
         self.trainer_status = RunningStatus.STOPPED
@@ -40,6 +40,30 @@ def __init__(self, config: Config):
         self.model_version = 0
         self.checkpoint_shard_counter = defaultdict(lambda: 0)
         self.ref_count = 0
+        self._modules = {module_ref}
+        asyncio.create_task(self._check_modules())
+
+    def add_module(self, module_ref: ray.actor.ActorHandle) -> None:
+        """Adds a module to be tracked by the synchronizer.
+
+        Args:
+            module_ref: The Ray actor handle of the module to track.
+        """
+        self._modules.add(module_ref)
+
+    async def _check_modules(self) -> None:
+        while len(self._modules) > 0:
+            alive_modules = set()
+            for module in self._modules:
+                try:
+                    await module.is_alive.remote()
+                    alive_modules.add(module)
+                except ray.exceptions.RayActorError:
+                    pass
+            self._modules = alive_modules
+            await asyncio.sleep(5)
+        self.logger.info("Synchronizer stopped.")
+        ray.actor.exit_actor()
 
     async def set_trainer_status(self, status: RunningStatus):
         """Update the status of the trainer."""
@@ -281,27 +305,16 @@ def get_actor(cls, config: Optional[Config] = None, namespace: Optional[str] = N
             A reference to the Synchronizer actor.
         """
         if config is not None:
-            if config.mode == "explore" or (
-                config.mode == "train" and config.algorithm.algorithm_type not in {"dpo", "sft"}
-            ):
-                lifetime = "detached"
-            else:
-                lifetime = None
-            return (
+            module_ref = ray.get_runtime_context().current_actor
+            synchronizer = (
                 ray.remote(cls)
                 .options(
                     name="synchronizer",
                     namespace=config.ray_namespace,
                     get_if_exists=True,
-                    lifetime=lifetime,
+                    lifetime="detached",
                 )
-                .remote(config)
+                .remote(config, module_ref=module_ref)
             )
+            ray.get(synchronizer.add_module.remote(module_ref))
         return ray.get_actor("synchronizer", namespace=namespace)
-
-    def acquire(self):
-        self.ref_count += 1
-
-    def release(self):
-        self.ref_count -= 1
-        return self.ref_count
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index ac954418ca..6e1dd28ea2 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -9,7 +9,6 @@
 from collections import deque
 from typing import List, Optional
 
-import ray
 import torch
 
 from trinity.algorithm import ADD_STRATEGY
@@ -166,10 +165,9 @@ async def prepare(self) -> None:
         """Preparation before running."""
         futures = [
             asyncio.create_task(self.scheduler.start()),
-            self.synchronizer.acquire.remote(),
         ]
         if self.experience_buffer:
-            futures.append(asyncio.create_task(self.experience_buffer.acquire()))
+            futures.append(asyncio.create_task(self.experience_buffer.acquire()))  # type: ignore
         if not self.use_nccl_sync:
             master_address, master_port = await self.models[0].get_available_address.remote()
             futures.append(
@@ -398,6 +396,7 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
     async def shutdown(self) -> None:
         await self.scheduler.stop()
         self.monitor.close()
-        if await self.synchronizer.release.remote() == 0:
-            ray.kill(self.synchronizer)
-            self.logger.info("Synchronizer stopped.")
+
+    def is_alive(self) -> bool:
+        """Check if the explorer is alive."""
+        return True
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index 0bc6c076f1..be6c943581 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -29,7 +29,6 @@ def __init__(self, config: Config) -> None:
         self.config = config
         self.logger = get_logger(__name__)
         self.synchronizer = Synchronizer.get_actor(config)
-        ray.get(self.synchronizer.acquire.remote())
         self.engine = get_trainer_wrapper(config)
         self.last_trainer_sync_step = 0
         self.monitor = MONITOR.get(config.monitor.monitor_type)(
@@ -154,15 +153,16 @@ def _log_experiences(self, samples: List[Dict]) -> None:
 
     async def shutdown(self) -> None:
         self.monitor.close()
-        if await self.synchronizer.release.remote() == 0:
-            ray.kill(self.synchronizer)
-            self.logger.info("Synchronizer stopped.")
 
     @property
     def train_step_num(self) -> int:
        """Get the current training step number."""
        return self.engine.train_step_num
 
+    def is_alive(self) -> bool:
+        """Check if the trainer is alive."""
+        return True
+
 
 class TrainEngineWrapper(ABC):
     """A wrapper class to wrap various training engines."""