diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 678be80a28..98e6a66c58 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -457,6 +457,23 @@ async def test_split_tasks(self): await scheduler.stop() + async def test_multi_step_execution(self): + self.config.explorer.max_repeat_times_per_runner = 1 + self.config.check_and_update() + scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) + await scheduler.start() + tasks = generate_tasks(2, repeat_times=4) + + n_steps = 3 + for i in range(1, n_steps + 1): + scheduler.schedule(tasks, batch_id=i) + statuses, exps = await scheduler.get_results(batch_id=i) + self.assertEqual(len(statuses), 2 * 4) + exps = self.queue.read(batch_size=2 * 4) + self.assertEqual(len(exps), 2 * 4) + + await scheduler.stop() + def tearDown(self): try: ray.shutdown() diff --git a/trinity/algorithm/add_strategy/add_strategy.py b/trinity/algorithm/add_strategy/add_strategy.py index cdff571ca0..c9083ea2ec 100644 --- a/trinity/algorithm/add_strategy/add_strategy.py +++ b/trinity/algorithm/add_strategy/add_strategy.py @@ -110,7 +110,7 @@ def calculate_group_advantage( group_reward_mean = torch.tensor(0.0) group_reward_std = torch.tensor(1.0) else: - rewards = torch.tensor([exp.reward for exp in exps]) + rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32) group_reward_mean = torch.mean(rewards) group_reward_std = torch.std(rewards) for exp in exps: @@ -155,7 +155,7 @@ def calculate_group_advantage( if len(exps) == 1: group_baseline = torch.tensor(0.0) else: - group_rewards = torch.tensor([exp.reward for exp in exps]) + group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32) if self.opmd_baseline == "mean": group_baseline = torch.mean(group_rewards) else: diff --git a/trinity/algorithm/advantage_fn/grpo_advantage.py b/trinity/algorithm/advantage_fn/grpo_advantage.py index 553af6d065..1d7dc1c858 100644 --- a/trinity/algorithm/advantage_fn/grpo_advantage.py +++ b/trinity/algorithm/advantage_fn/grpo_advantage.py @@ -59,8 +59,8 @@ def __call__( id2mean[idx] = torch.tensor(0.0) id2std[idx] = torch.tensor(1.0) elif len(id2score[idx]) > 1: - id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) - id2std[idx] = torch.std(torch.tensor([id2score[idx]])) + id2mean[idx] = torch.mean(torch.tensor(id2score[idx], dtype=torch.float32)) + id2std[idx] = torch.std(torch.tensor(id2score[idx], dtype=torch.float32)) else: raise ValueError(f"no score in prompt index: {idx}") for i in range(bsz): diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py index b27e2c9ab0..871bda7500 100644 --- a/trinity/algorithm/advantage_fn/opmd_advantage.py +++ b/trinity/algorithm/advantage_fn/opmd_advantage.py @@ -61,9 +61,11 @@ def __call__( # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?) elif len(id2score[idx]) > 1: if opmd_baseline == "mean": - id2baseline[idx] = torch.mean(torch.tensor(id2score[idx])) + id2baseline[idx] = torch.mean( + torch.tensor(id2score[idx], dtype=torch.float32) + ) elif opmd_baseline == "logavgexp": - rewards_tensor = torch.tensor(id2score[idx]) + rewards_tensor = torch.tensor(id2score[idx], dtype=torch.float32) # here we use the fact that logavgexp(x) = logsumexp(x) - log(len(x)) id2baseline[idx] = tau * ( torch.logsumexp(rewards_tensor / tau, dim=-1) diff --git a/trinity/algorithm/advantage_fn/rloo_advantage.py b/trinity/algorithm/advantage_fn/rloo_advantage.py index fb2680a68b..5cc079e687 100644 --- a/trinity/algorithm/advantage_fn/rloo_advantage.py +++ b/trinity/algorithm/advantage_fn/rloo_advantage.py @@ -50,7 +50,7 @@ def __call__( if len(id2score[idx]) == 1: id2mean[idx] = torch.tensor(0.0) elif len(id2score[idx]) > 1: - id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) + id2mean[idx] = torch.mean(torch.tensor(id2score[idx], dtype=torch.float32)) else: raise ValueError(f"no score in prompt index: {idx}") for i in range(bsz): diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index ceabdc771a..e7f0e5c775 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -37,8 +37,8 @@ class Task(dict): raw_task: Optional[dict] = None # The raw data sample # automatically assigned ids - batch_id: int = 0 - task_id: int = 0 + batch_id: Union[int, str] = 0 + task_id: Union[int, str] = 0 def to_workflow( self, model: Any, auxiliary_models: Optional[List[openai.OpenAI]] = None diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 2c3cb0c5fb..7b64fb2eaf 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -5,7 +5,7 @@ import time import traceback from collections import defaultdict, deque -from dataclasses import dataclass +from dataclasses import dataclass, replace from typing import Dict, List, Optional, Tuple, Union import ray @@ -24,7 +24,6 @@ class TaskWrapper: task: Task batch_id: Union[int, str] - repeat_times: int class RunnerWrapper: @@ -75,7 +74,6 @@ async def run_with_retry(self, task: TaskWrapper) -> Tuple[Status, List, int]: try: for attempt in range(self.retry_times + 1): try: - task.task.rollout_args.n = task.repeat_times status, exps = await asyncio.wait_for( self.runner.run_task.remote(task.task), self.timeout ) @@ -297,25 +295,27 @@ def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None: def _split_and_submit_tasks(self, tasks: List[Task], batch_id: Union[int, str]) -> None: for i, task in enumerate(tasks): - task.batch_id = batch_id - task.task_id = i if self.max_repeat_times is None: self.pending_tasks[batch_id].appendleft( TaskWrapper( - task=task, + task=replace(task, batch_id=batch_id, task_id=i), batch_id=batch_id, - repeat_times=task.rollout_args.n, ) ) continue rest_repeat_times = task.rollout_args.n while rest_repeat_times > 0: + repeat_times = min(self.max_repeat_times, rest_repeat_times) task_wrapper = TaskWrapper( - task=task, + task=replace( + task, + batch_id=batch_id, + task_id=i, + rollout_args=replace(task.rollout_args, n=repeat_times), + ), batch_id=batch_id, - repeat_times=min(self.max_repeat_times, rest_repeat_times), ) - rest_repeat_times -= task_wrapper.repeat_times + rest_repeat_times -= repeat_times self.pending_tasks[batch_id].appendleft(task_wrapper) async def get_results(