17 changes: 17 additions & 0 deletions tests/explorer/scheduler_test.py
@@ -457,6 +457,23 @@ async def test_split_tasks(self):

await scheduler.stop()

async def test_multi_step_execution(self):
self.config.explorer.max_repeat_times_per_runner = 1
self.config.check_and_update()
scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
await scheduler.start()
tasks = generate_tasks(2, repeat_times=4)

n_steps = 3
for i in range(1, n_steps + 1):
scheduler.schedule(tasks, batch_id=i)
statuses, exps = await scheduler.get_results(batch_id=i)
self.assertEqual(len(statuses), 2 * 4)
exps = self.queue.read(batch_size=2 * 4)
self.assertEqual(len(exps), 2 * 4)

await scheduler.stop()

def tearDown(self):
try:
ray.shutdown()
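A quick sanity check of the assertions in test_multi_step_execution (illustrative only, not part of this diff): with max_repeat_times_per_runner = 1, a task whose rollout_args.n is 4 is presumably split into 4 single-rollout sub-tasks, so each of the three batches should yield 2 * 4 = 8 statuses and 8 experiences.

# Rough arithmetic behind the asserts above; variable names are illustrative.
num_tasks, repeat_times, max_repeat_per_runner = 2, 4, 1
sub_tasks_per_task = -(-repeat_times // max_repeat_per_runner)  # ceil division -> 4
results_per_batch = num_tasks * repeat_times                    # 8 statuses / experiences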
4 changes: 2 additions & 2 deletions trinity/algorithm/add_strategy/add_strategy.py
@@ -110,7 +110,7 @@ def calculate_group_advantage(
group_reward_mean = torch.tensor(0.0)
group_reward_std = torch.tensor(1.0)
else:
rewards = torch.tensor([exp.reward for exp in exps])
rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32)
group_reward_mean = torch.mean(rewards)
group_reward_std = torch.std(rewards)
for exp in exps:
@@ -155,7 +155,7 @@ def calculate_group_advantage(
if len(exps) == 1:
group_baseline = torch.tensor(0.0)
else:
group_rewards = torch.tensor([exp.reward for exp in exps])
group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32)
if self.opmd_baseline == "mean":
group_baseline = torch.mean(group_rewards)
else:
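A minimal sketch of why the explicit dtype matters here (my reading of the change, not part of this diff): if every exp.reward happens to be a Python int, e.g. 0/1 correctness rewards, torch.tensor infers an integer dtype, and torch.mean / torch.std do not accept integer tensors, so the group statistics would fail at runtime. Forcing float32 keeps them well defined:

import torch

scores = [1, 0, 1, 1]  # all-int rewards would otherwise be inferred as an integer tensor
# torch.mean(torch.tensor(scores))  # would fail: mean/std require a floating dtype
rewards = torch.tensor(scores, dtype=torch.float32)
print(torch.mean(rewards), torch.std(rewards))  # tensor(0.7500) tensor(0.5000)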
4 changes: 2 additions & 2 deletions trinity/algorithm/advantage_fn/grpo_advantage.py
@@ -59,8 +59,8 @@ def __call__(
id2mean[idx] = torch.tensor(0.0)
id2std[idx] = torch.tensor(1.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
id2std[idx] = torch.std(torch.tensor([id2score[idx]]))
id2mean[idx] = torch.mean(torch.tensor(id2score[idx], dtype=torch.float32))
id2std[idx] = torch.std(torch.tensor(id2score[idx], dtype=torch.float32))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
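Note that besides adding the dtype, the std line above also drops the stray extra brackets around id2score[idx], so torch.std now operates on a 1-D tensor of scores rather than a nested [ [...] ] wrapper.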
6 changes: 4 additions & 2 deletions trinity/algorithm/advantage_fn/opmd_advantage.py
@@ -61,9 +61,11 @@ def __call__(
# TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?)
elif len(id2score[idx]) > 1:
if opmd_baseline == "mean":
id2baseline[idx] = torch.mean(torch.tensor(id2score[idx]))
id2baseline[idx] = torch.mean(
torch.tensor(id2score[idx], dtype=torch.float32)
)
elif opmd_baseline == "logavgexp":
rewards_tensor = torch.tensor(id2score[idx])
rewards_tensor = torch.tensor(id2score[idx], dtype=torch.float32)
# here we use the fact that logavgexp(x) = logsumexp(x) - log(len(x))
id2baseline[idx] = tau * (
torch.logsumexp(rewards_tensor / tau, dim=-1)
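The logavgexp branch relies on the identity noted in the comment, logavgexp(x) = logsumexp(x) - log(len(x)), applied at temperature tau. A small self-contained check of that identity with illustrative values (not taken from the codebase):

import math

import torch

rewards = torch.tensor([0.2, 0.8, 1.0], dtype=torch.float32)
tau = 0.5

# tau * logavgexp(r / tau) computed via logsumexp, as in the file above
baseline = tau * (torch.logsumexp(rewards / tau, dim=-1) - math.log(rewards.numel()))
# direct form: tau * log(mean(exp(r / tau)))
direct = tau * torch.log(torch.mean(torch.exp(rewards / tau)))
assert torch.allclose(baseline, direct)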
2 changes: 1 addition & 1 deletion trinity/algorithm/advantage_fn/rloo_advantage.py
@@ -50,7 +50,7 @@ def __call__(
if len(id2score[idx]) == 1:
id2mean[idx] = torch.tensor(0.0)
elif len(id2score[idx]) > 1:
id2mean[idx] = torch.mean(torch.tensor(id2score[idx]))
id2mean[idx] = torch.mean(torch.tensor(id2score[idx], dtype=torch.float32))
else:
raise ValueError(f"no score in prompt index: {idx}")
for i in range(bsz):
4 changes: 2 additions & 2 deletions trinity/common/workflows/workflow.py
@@ -37,8 +37,8 @@ class Task(dict):
raw_task: Optional[dict] = None # The raw data sample

# automatically assigned ids
batch_id: int = 0
task_id: int = 0
batch_id: Union[int, str] = 0
task_id: Union[int, str] = 0

def to_workflow(
self, model: Any, auxiliary_models: Optional[List[openai.OpenAI]] = None
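Widening batch_id / task_id to Union[int, str] keeps Task consistent with the scheduler below, whose schedule() and TaskWrapper already type batch_id as Union[int, str].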
20 changes: 10 additions & 10 deletions trinity/explorer/scheduler.py
@@ -5,7 +5,7 @@
import time
import traceback
from collections import defaultdict, deque
from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import Dict, List, Optional, Tuple, Union

import ray
@@ -24,7 +24,6 @@ class TaskWrapper:

task: Task
batch_id: Union[int, str]
repeat_times: int


class RunnerWrapper:
@@ -75,7 +74,6 @@ async def run_with_retry(self, task: TaskWrapper) -> Tuple[Status, List, int]:
try:
for attempt in range(self.retry_times + 1):
try:
task.task.rollout_args.n = task.repeat_times
status, exps = await asyncio.wait_for(
self.runner.run_task.remote(task.task), self.timeout
)
@@ -297,25 +295,27 @@ def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None:

def _split_and_submit_tasks(self, tasks: List[Task], batch_id: Union[int, str]) -> None:
for i, task in enumerate(tasks):
task.batch_id = batch_id
task.task_id = i
if self.max_repeat_times is None:
self.pending_tasks[batch_id].appendleft(
TaskWrapper(
task=task,
task=replace(task, batch_id=batch_id, task_id=i),
batch_id=batch_id,
repeat_times=task.rollout_args.n,
)
)
continue
rest_repeat_times = task.rollout_args.n
while rest_repeat_times > 0:
repeat_times = min(self.max_repeat_times, rest_repeat_times)
task_wrapper = TaskWrapper(
task=task,
task=replace(
task,
batch_id=batch_id,
task_id=i,
rollout_args=replace(task.rollout_args, n=repeat_times),
),
batch_id=batch_id,
repeat_times=min(self.max_repeat_times, rest_repeat_times),
)
rest_repeat_times -= task_wrapper.repeat_times
rest_repeat_times -= repeat_times
self.pending_tasks[batch_id].appendleft(task_wrapper)

async def get_results(
Expand Down