3 changes: 2 additions & 1 deletion tests/common/synchronizer_test.py
@@ -34,10 +34,11 @@


def trainer_monkey_patch(config: Config, max_steps: int, intervals: List[int]):
def new_train_step(self):
async def new_train_step(self):
self.engine.algorithm = ALGORITHM_TYPE.get(config.algorithm.algorithm_type)
self.engine.global_steps += 1
self.logger.info(f"Training at step {self.engine.global_steps} started.")
await asyncio.sleep(0.1)
time.sleep(intervals[self.engine.global_steps - 1])
metrics = {"actor/step": self.engine.global_steps}
self.monitor.log(data=metrics, step=self.engine.global_steps)
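Since the monkey-patched `new_train_step` is now a coroutine, whatever drives it in the test has to await it rather than call it directly. A minimal sketch of such a driver, assuming plain `asyncio` and a `trainer` object whose `train_step` has been replaced by the patched coroutine (the names `run_patched_steps` and `trainer` are illustrative, not from the PR):

```python
import asyncio

async def run_patched_steps(trainer, max_steps: int) -> None:
    # Each call now returns a coroutine; awaiting it lets the
    # asyncio.sleep(0.1) inside the patch yield to the event loop,
    # while time.sleep still models the per-step training cost.
    for _ in range(max_steps):
        await trainer.train_step()

# From a synchronous test body this could be driven with:
# asyncio.run(run_patched_steps(trainer, max_steps))
```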
6 changes: 3 additions & 3 deletions trinity/algorithm/sample_strategy/mix_sample_strategy.py
@@ -44,16 +44,16 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config
)

def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
metrics = {}
with Timer(metrics, "read_time"):
usual_exp_list = self.usual_exp_buffer.read()
usual_exp_list = await self.usual_exp_buffer.read_async()
for exp in usual_exp_list:
if exp.info is None:
exp.info = {}
exp.info["is_expert"] = False

expert_exp_list = self.expert_exp_buffer.read()
expert_exp_list = await self.expert_exp_buffer.read_async()
for exp in expert_exp_list:
exp.reward = 0.0
exp.logprobs = torch.zeros_like(
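With `sample` now a coroutine, the usual and expert buffers are awaited one after the other. If the two readers are independent, the reads could in principle be overlapped; the sketch below is only an illustrative alternative under that assumption, not what the PR implements:

```python
import asyncio

async def read_both(usual_buffer, expert_buffer):
    # Overlap the two awaits; each read_async is assumed to be
    # safe to run concurrently with the other.
    usual_exp_list, expert_exp_list = await asyncio.gather(
        usual_buffer.read_async(),
        expert_buffer.read_async(),
    )
    return usual_exp_list, expert_exp_list
```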
12 changes: 6 additions & 6 deletions trinity/algorithm/sample_strategy/sample_strategy.py
@@ -16,7 +16,7 @@ def __init__(self, buffer_config: BufferConfig, **kwargs) -> None:
self.pad_token_id = buffer_config.pad_token_id

@abstractmethod
def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
"""Sample data from buffer.

Args:
@@ -53,13 +53,13 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
else:
self.sft_buffer = None

def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
metrics = {}
with Timer(metrics, "read_time"):
if step <= self.sft_warmup_steps:
exp_list = self.sft_buffer.read()
exp_list = await self.sft_buffer.read_async()
else:
exp_list = self.exp_buffer.read()
exp_list = await self.exp_buffer.read_async()
repr_samples = representative_sample(exp_list)
with Timer(metrics, "gather_time"):
exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore
@@ -78,10 +78,10 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore
)

def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
async def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
metrics = {}
with Timer(metrics, "read_time"):
exp_list = self.exp_buffer.read()
exp_list = await self.exp_buffer.read_async()
repr_samples = representative_sample(exp_list)
with Timer(metrics, "gather_time"):
exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore
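Because the strategy's `sample` is now `async`, its call site in the trainer (not shown in this diff) has to await it. A hedged sketch of what such a call site might look like, with `strategy` standing in for any `SampleStrategy` subclass and `next_batch` being an illustrative name:

```python
from typing import Any, Dict, List, Tuple

async def next_batch(strategy, step: int) -> Tuple[Any, Dict, List]:
    # sample() now returns a coroutine; awaiting it yields the gathered
    # experiences plus timing metrics and representative samples.
    exps, metrics, repr_samples = await strategy.sample(step)
    return exps, metrics, repr_samples
```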
6 changes: 6 additions & 0 deletions trinity/buffer/buffer_reader.py
@@ -13,3 +13,9 @@ def read(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
) -> List:
"""Read from buffer."""

@abstractmethod
async def read_async(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
) -> List:
"""Read from buffer asynchronously."""
20 changes: 20 additions & 0 deletions trinity/buffer/reader/file_reader.py
@@ -182,6 +182,11 @@ def read(
raise ValueError(f"Unknown data format: {self.prompt_type}")
return exp_list

async def read_async(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
):
return self.read(batch_size, strategy)


@FILE_READERS.register_module(DPOAlgorithm.name())
class DPODataReader(BufferReader):
@@ -259,6 +264,11 @@ def read(
exp_list.append(experience)
return exp_list

async def read_async(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
):
return self.read(batch_size, strategy)


@FILE_READERS.register_module("rollout")
class RolloutDataReader(BufferReader):
@@ -323,6 +333,11 @@ def read(
tasks.append(task)
return tasks

async def read_async(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
):
return self.read(batch_size, strategy)


@FILE_READERS.register_module("raw")
class RawDataReader(BufferReader):
@@ -340,3 +355,8 @@ def read(
raise StopIteration
self.returned = True
return self.dataset.to_list()

async def read_async(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
):
return self.read(batch_size, strategy)
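All four file-backed readers gain the same thin wrapper that simply forwards to the synchronous `read`, which is a reasonable choice here because these datasets are already loaded in memory. One possible way to avoid repeating the wrapper would be a small mixin, sketched below purely as an illustration (the mixin name is hypothetical and this is not what the PR does):

```python
class SyncReadAsyncMixin:  # hypothetical refactor, for illustration only
    async def read_async(self, batch_size=None, strategy=None):
        # For in-memory file readers the synchronous read() is fast,
        # so calling it inline from the coroutine does not stall the loop.
        return self.read(batch_size, strategy)
```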
16 changes: 16 additions & 0 deletions trinity/buffer/reader/queue_reader.py
@@ -37,3 +37,19 @@ def read(
except StopAsyncIteration:
raise StopIteration()
return exps

async def read_async(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
) -> List:
if strategy is not None and strategy != ReadStrategy.FIFO:
raise NotImplementedError(f"Read strategy {strategy} not supported for Queue Reader.")
try:
batch_size = batch_size or self.read_batch_size
exps = await self.queue.get_batch.remote(batch_size, timeout=self.timeout)
if len(exps) != batch_size:
raise TimeoutError(
f"Read incomplete batch ({len(exps)}/{batch_size}), please check your workflow."
)
except StopAsyncIteration:
raise StopIteration()
return exps
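The new coroutine awaits the Ray actor call directly instead of blocking on `ray.get`, which lets the caller's event loop keep running while the queue actor gathers a batch. Ray object refs returned by `.remote()` are awaitable, as the toy example below shows (the `Echo` actor is an illustrative stand-in for the queue actor):

```python
import asyncio
import ray

@ray.remote
class Echo:  # toy stand-in for the queue actor
    def get_batch(self, batch_size: int):
        return list(range(batch_size))

async def main():
    ray.init(ignore_reinit_error=True)
    echo = Echo.remote()
    # Awaiting the ObjectRef suspends only this coroutine instead of
    # blocking the whole event loop the way ray.get(...) would.
    batch = await echo.get_batch.remote(4)
    print(batch)

if __name__ == "__main__":
    asyncio.run(main())
```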
8 changes: 8 additions & 0 deletions trinity/buffer/reader/sql_reader.py
@@ -25,3 +25,11 @@ def read(
return ray.get(self.db_wrapper.read.remote(batch_size, strategy))
else:
return self.db_wrapper.read(batch_size, strategy)

async def read_async(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
) -> List:
if self.wrap_in_ray:
return await self.db_wrapper.read.remote(batch_size, strategy)
else:
return self.db_wrapper.read(batch_size, strategy)
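When `wrap_in_ray` is false the coroutine still calls the blocking database read inline, so a long-running SQL query would hold up the event loop. If that ever matters, one hedged alternative is to push the call onto a worker thread; this is only a sketch under that assumption, not part of the PR:

```python
import asyncio
from typing import List, Optional

async def read_in_thread(db_wrapper, batch_size: Optional[int] = None, strategy=None) -> List:
    # asyncio.to_thread (Python 3.9+) runs the blocking read in the
    # default executor and resumes this coroutine with the result.
    return await asyncio.to_thread(db_wrapper.read, batch_size, strategy)
```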
17 changes: 15 additions & 2 deletions trinity/common/synchronizer.py
@@ -184,7 +184,7 @@ async def wait_new_model_state_dict(self, current_version: int, no_wait: bool =
)
if self.model_version > current_version:
self.set_explorer_status(
RunningStatus.WAITING_SYNC, old_status=RunningStatus.REQUIRE_SYNC
RunningStatus.RUNNING, old_status=RunningStatus.REQUIRE_SYNC
)
return self.model_version

@@ -237,10 +237,17 @@ def sync_failed():
await asyncio.wait_for(
self._ready_condition.wait_for(
lambda: self.explorer_status_counts[RunningStatus.WAITING_SYNC]
+ self.explorer_status_counts[RunningStatus.STOPPED]
== 1,
),
timeout=self.config.synchronizer.sync_timeout,
)
if self.explorer_status_counts[RunningStatus.STOPPED] == 1:
return sync_failed()
self.set_explorer_status(
RunningStatus.RUNNING,
old_status=RunningStatus.WAITING_SYNC,
)
elif module == "explorer":
self.set_explorer_status(
RunningStatus.WAITING_SYNC, old_status=RunningStatus.REQUIRE_SYNC
@@ -274,13 +281,19 @@ def get_actor(cls, config: Optional[Config] = None, namespace: Optional[str] = N
A reference to the Synchronizer actor.
"""
if config is not None:
if config.mode == "explore" or (
config.mode == "train" and config.algorithm.algorithm_type not in {"dpo", "sft"}
):
lifetime = "detached"
else:
lifetime = None
return (
ray.remote(cls)
.options(
name="synchronizer",
namespace=config.ray_namespace,
get_if_exists=True,
lifetime="detached",
lifetime=lifetime,
)
.remote(config)
)
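The lifetime of the named `synchronizer` actor is now chosen from the run mode: it stays `"detached"` only when another job (an explorer, or a trainer running a non-SFT/DPO algorithm) is presumably expected to look it up later, and otherwise falls back to the default job-scoped lifetime so a pure SFT/DPO training run does not leave a detached actor behind. The snippet below is a stand-alone illustration of how the two lifetimes behave together with `get_if_exists`; the `Registry` actor and the `demo` namespace are made up for the example:

```python
import ray

@ray.remote
class Registry:  # made-up actor used only to illustrate actor lifetimes
    def ping(self) -> str:
        return "pong"

ray.init(ignore_reinit_error=True)

# Detached: survives the creating job, so a separately launched job in
# the same namespace can fetch it by name with get_if_exists=True.
detached = Registry.options(
    name="registry", namespace="demo", get_if_exists=True, lifetime="detached"
).remote()

# Default lifetime: tied to this job and cleaned up when the job exits.
job_scoped = Registry.options(
    name="registry-local", namespace="demo", get_if_exists=True
).remote()

print(ray.get(detached.ping.remote()), ray.get(job_scoped.ping.remote()))
```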
31 changes: 15 additions & 16 deletions trinity/explorer/explorer.py
@@ -140,9 +140,6 @@ async def _pull_latest_weights(self):
)
self.model_version = new_version
self.last_sync_step = self.explore_step_num
await self.synchronizer.set_explorer_status.remote(
RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC
)
self.last_sync_successful = True
else:
self.logger.warning(
@@ -163,9 +160,6 @@ async def _nccl_weights_update(self):
*[model.sync_model.remote(self.model_version) for model in self.models]
)
self.last_sync_step = self.explore_step_num
await self.synchronizer.set_explorer_status.remote(
RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC
)
self.last_sync_successful = True

async def prepare(self) -> None:
@@ -183,7 +177,7 @@ async def prepare(self) -> None:
)
await asyncio.gather(*futures, return_exceptions=True)
if self.config.explorer.eval_on_startup and self.explore_step_num == 0:
self.eval()
await self.eval()
await self.synchronizer.set_explorer_status.remote(RunningStatus.REQUIRE_SYNC)

async def get_weight(self, name: str) -> torch.Tensor:
@@ -210,7 +204,7 @@ async def explore(self) -> str:
# TODO: support eval on last checkpoint
break
if self.need_eval():
self.eval()
await self.eval()
if await self.need_sync():
await self.sync_weight()
except Exception:
@@ -228,8 +222,10 @@ async def explore_step(self) -> bool:
self.explore_step_num += 1
return True
try:
tasks = self.taskset.read()
except StopIteration:
tasks = await self.taskset.read_async()
except (StopIteration, RuntimeError) as e:
if isinstance(e, RuntimeError) and "StopIteration" not in str(e):
raise
self.logger.warning("No more tasks to explore. Stop exploring.")
await self.save_checkpoint(sync_weight=False)
await self.synchronizer.set_explorer_status.remote(
@@ -270,7 +266,7 @@ async def need_sync(self) -> bool:
def need_eval(self) -> bool:
return self.explore_step_num % self.config.explorer.eval_interval == 0

def eval(self):
async def eval(self):
"""Evaluation on all evaluation data samples."""
if len(self.config.buffer.explorer_input.eval_tasksets) == 0:
self.logger.warning("No evaluation data samples. Skip evaluation.")
@@ -285,16 +281,19 @@ def eval(self):
self.pending_eval_tasks.append((self.explore_step_num, eval_taskset.name))
while True:
try:
self.scheduler.schedule(eval_taskset.read(), batch_id=eval_batch_id)
except StopIteration:
data = await eval_taskset.read_async()
self.scheduler.schedule(data, batch_id=eval_batch_id)
except (StopIteration, RuntimeError) as e:
if isinstance(e, RuntimeError) and "StopIteration" not in str(e):
raise
break

async def benchmark(self) -> bool:
"""Benchmark the model checkpoints."""
# benchmark on the latest checkpoint
if self.config.explorer.bench_on_latest_checkpoint:
self.explore_step_num = await self._checkpoint_weights_update()
self.eval()
await self.eval()
await self._finish_eval_step(prefix="bench")
return True

@@ -313,7 +312,7 @@ async def benchmark(self) -> bool:
)
for step_num in all_ckp_steps:
self.explore_step_num = await self._checkpoint_weights_update(step_num=step_num)
self.eval()
await self.eval()
await self._finish_eval_step(prefix="bench")
return True

Expand Down Expand Up @@ -391,8 +390,8 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
self.monitor.log(metric, step)

async def shutdown(self) -> None:
await self.scheduler.stop()
self.monitor.close()
if await self.synchronizer.release.remote() == 0:
ray.kill(self.synchronizer)
self.logger.info("Synchronizer stopped.")
await self.scheduler.stop()
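The widened `except (StopIteration, RuntimeError)` clauses appear to account for PEP 479: a `StopIteration` that escapes a coroutine (here, a `read_async` that simply delegates to a `read` raising `StopIteration`) is replaced by Python with `RuntimeError("coroutine raised StopIteration")`, so the explorer inspects the message before treating the error as end-of-data. A small self-contained demonstration of that conversion:

```python
import asyncio

def exhausted_read():
    # Several readers signal "no more data" this way.
    raise StopIteration

async def read_async():
    # Delegating to the synchronous read inside a coroutine,
    # as the file readers in this PR do.
    return exhausted_read()

try:
    asyncio.run(read_async())
except RuntimeError as e:
    # PEP 479: the StopIteration is converted before it leaves the coroutine.
    assert "StopIteration" in str(e)
    print("end of data:", e)
```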