73 changes: 73 additions & 0 deletions tests/common/synchronizer_test.py
@@ -105,6 +105,79 @@ def tearDown(self):
shutil.rmtree(os.path.join(checkpoint_path, "unittest"))


class TestSynchronizerExit(BaseTestSynchronizer):
def test_synchronizer(self):
config = get_template_config()
config.project = "unittest"
config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}"
config.checkpoint_root_dir = get_checkpoint_path()
config.buffer.total_epochs = 1
config.buffer.batch_size = 4
config.cluster.gpu_per_node = 2
config.cluster.node_num = 1
config.model.model_path = get_model_path()
config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
config.buffer.trainer_input.experience_buffer = StorageConfig(
name="exp_buffer",
storage_type=StorageType.QUEUE,
wrap_in_ray=True,
)
config.synchronizer.sync_method = SyncMethod.CHECKPOINT
config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER
config.synchronizer.sync_interval = 2
config.trainer.save_interval = 100
config.monitor.monitor_type = "tensorboard"
trainer_config = deepcopy(config)
trainer_config.mode = "train"
trainer_config.check_and_update()

explorer1_config = deepcopy(config)
explorer1_config.mode = "explore"
explorer1_config.explorer.name = "explorer1"
explorer1_config.explorer.rollout_model.engine_num = 1
explorer1_config.explorer.rollout_model.tensor_parallel_size = 1
explorer1_config.buffer.explorer_output = StorageConfig(
name="exp_buffer",
storage_type=StorageType.QUEUE,
wrap_in_ray=True,
)
explorer1_config.check_and_update()

trainer_process = multiprocessing.Process(
target=run_trainer, args=(trainer_config, 8, [2, 1, 2, 1, 2, 1, 2, 1])
)
trainer_process.start()
ray.init(ignore_reinit_error=True)
while True:
try:
synchronizer = ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace)
break
except ValueError:
print("waiting for trainer to start.")
time.sleep(5)
explorer_process_1 = multiprocessing.Process(
target=run_explorer,
args=(explorer1_config, 8, [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5]),
)
explorer_process_1.start()

self.assertEqual(
synchronizer, ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace)
)
time.sleep(5)
trainer_process.terminate()
trainer_process.join()
self.assertEqual(
synchronizer, ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace)
)

explorer_process_1.terminate()
explorer_process_1.join()
time.sleep(6)
with self.assertRaises(ValueError):
ray.get_actor("synchronizer", namespace=trainer_config.ray_namespace)


@parameterized_class(
(
"sync_method",
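A note on the two timing assumptions in the test above: ray.get_actor raises ValueError as long as no actor is registered under the given name, so the startup loop simply retries until the trainer process has created the synchronizer; and the closing assertRaises(ValueError) only holds because the synchronizer's liveness check runs every 5 seconds, which the 6-second sleep is sized to cover. A minimal helper capturing the same retry pattern (wait_for_actor is an illustrative name, not part of this PR):

import time

import ray


def wait_for_actor(name: str, namespace: str, timeout: float = 60.0) -> "ray.actor.ActorHandle":
    """Poll until a named actor is registered; illustrative sketch only."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            return ray.get_actor(name, namespace=namespace)
        except ValueError:  # raised while the name is not registered yet
            if time.monotonic() >= deadline:
                raise
            time.sleep(1.0)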
47 changes: 30 additions & 17 deletions trinity/common/synchronizer.py
@@ -30,7 +30,7 @@ class Synchronizer:
checkpoint_shard_counter: Tracks how many shards are received from trainer for a specific train step.
"""

-    def __init__(self, config: Config):
+    def __init__(self, config: Config, module_ref: ray.actor.ActorHandle):
self.logger = get_logger(__name__)
self.config = config
self.trainer_status = RunningStatus.STOPPED
@@ -40,6 +40,30 @@ def __init__(self, config: Config):
self.model_version = 0
self.checkpoint_shard_counter = defaultdict(lambda: 0)
self.ref_count = 0
+        self._modules = {module_ref}
+        asyncio.create_task(self._check_modules())
+
+    def add_module(self, module_ref: ray.actor.ActorHandle) -> None:
+        """Adds a module to be tracked by the synchronizer.
+
+        Args:
+            module_ref: The Ray actor handle of the module to track.
+        """
+        self._modules.add(module_ref)
+
+    async def _check_modules(self) -> None:
+        while len(self._modules) > 0:
+            alive_modules = set()
+            for module in self._modules:
+                try:
+                    await module.is_alive.remote()
+                    alive_modules.add(module)
+                except ray.exceptions.RayActorError:
+                    pass
+            self._modules = alive_modules
+            await asyncio.sleep(5)
+        self.logger.info("Synchronizer stopped.")
+        ray.actor.exit_actor()

async def set_trainer_status(self, status: RunningStatus):
"""Update the status of the trainer."""
@@ -281,27 +305,16 @@ def get_actor(cls, config: Optional[Config] = None, namespace: Optional[str] = N
A reference to the Synchronizer actor.
"""
if config is not None:
if config.mode == "explore" or (
config.mode == "train" and config.algorithm.algorithm_type not in {"dpo", "sft"}
):
lifetime = "detached"
else:
lifetime = None
return (
module_ref = ray.get_runtime_context().current_actor
synchronizer = (
ray.remote(cls)
.options(
name="synchronizer",
namespace=config.ray_namespace,
get_if_exists=True,
-                    lifetime=lifetime,
+                    lifetime="detached",
                )
-                .remote(config)
+                .remote(config, module_ref=module_ref)
            )
+            ray.get(synchronizer.add_module.remote(module_ref))
return ray.get_actor("synchronizer", namespace=namespace)

-    def acquire(self):
-        self.ref_count += 1
-
-    def release(self):
-        self.ref_count -= 1
-        return self.ref_count
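One subtlety in the new get_actor worth making explicit: with get_if_exists=True, the constructor runs only for whichever caller actually creates the actor, so the module_ref argument passed by every later caller is silently discarded; the follow-up add_module call is what registers each caller regardless of creation order. A sketch of a caller, under the PR's own definitions (cfg stands in for a concrete Config and is illustrative):

me = ray.get_runtime_context().current_actor  # the calling module's own handle
synchronizer = (
    ray.remote(Synchronizer)
    .options(
        name="synchronizer",
        namespace=cfg.ray_namespace,
        get_if_exists=True,    # idempotent: only the first caller constructs
        lifetime="detached",   # decoupled from the creating process
    )
    .remote(cfg, module_ref=me)  # constructor argument; used only once
)
ray.get(synchronizer.add_module.remote(me))  # takes effect for every caller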
11 changes: 5 additions & 6 deletions trinity/explorer/explorer.py
@@ -9,7 +9,6 @@
from collections import deque
from typing import List, Optional

-import ray
import torch

from trinity.algorithm import ADD_STRATEGY
@@ -166,10 +165,9 @@ async def prepare(self) -> None:
"""Preparation before running."""
futures = [
asyncio.create_task(self.scheduler.start()),
-            self.synchronizer.acquire.remote(),
]
if self.experience_buffer:
-            futures.append(asyncio.create_task(self.experience_buffer.acquire()))
+            futures.append(asyncio.create_task(self.experience_buffer.acquire()))  # type: ignore
if not self.use_nccl_sync:
master_address, master_port = await self.models[0].get_available_address.remote()
futures.append(
@@ -398,6 +396,7 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
async def shutdown(self) -> None:
await self.scheduler.stop()
self.monitor.close()
-        if await self.synchronizer.release.remote() == 0:
-            ray.kill(self.synchronizer)
-            self.logger.info("Synchronizer stopped.")
+
+    def is_alive(self) -> bool:
+        """Check if the explorer is alive."""
+        return True
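Note that the new is_alive above never returns False, and it does not need to: the synchronizer's check treats any completed call as proof of liveness, since a dead actor process cannot answer at all and the awaited call raises RayActorError instead. The caller-side view, as an illustrative helper (still_alive is not part of this PR):

import ray


async def still_alive(handle: "ray.actor.ActorHandle") -> bool:
    try:
        await handle.is_alive.remote()  # the return value is irrelevant
        return True
    except ray.exceptions.RayActorError:  # the process behind the handle died
        return False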
8 changes: 4 additions & 4 deletions trinity/trainer/trainer.py
@@ -29,7 +29,6 @@ def __init__(self, config: Config) -> None:
self.config = config
self.logger = get_logger(__name__)
self.synchronizer = Synchronizer.get_actor(config)
-        ray.get(self.synchronizer.acquire.remote())
self.engine = get_trainer_wrapper(config)
self.last_trainer_sync_step = 0
self.monitor = MONITOR.get(config.monitor.monitor_type)(
@@ -154,15 +153,16 @@ def _log_experiences(self, samples: List[Dict]) -> None:

async def shutdown(self) -> None:
self.monitor.close()
-        if await self.synchronizer.release.remote() == 0:
-            ray.kill(self.synchronizer)
-            self.logger.info("Synchronizer stopped.")

@property
def train_step_num(self) -> int:
"""Get the current training step number."""
return self.engine.train_step_num

+    def is_alive(self) -> bool:
+        """Check if the trainer is alive."""
+        return True


class TrainEngineWrapper(ABC):
"""A wrapper class to wrap various training engines."""
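With acquire/release gone from both trainer and explorer, nothing ever calls ray.kill on the synchronizer; the actor persists purely through its detached lifetime and exits on its own once _check_modules finds no live module. This is also what the test exercises when trainer_process.terminate() leaves the synchronizer reachable. A minimal illustration of detached-lifetime semantics (the Sync class and "demo" namespace are made up):

import ray

ray.init(namespace="demo")


@ray.remote
class Sync:
    def ping(self) -> str:
        return "pong"


# A detached named actor outlives the driver that created it: after this
# script exits, a new driver in namespace "demo" can still reach it with
# ray.get_actor("sync"), until the actor itself calls ray.actor.exit_actor().
Sync.options(name="sync", lifetime="detached", get_if_exists=True).remote()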