6 changes: 3 additions & 3 deletions .github/workflows/unittest.yaml
@@ -66,22 +66,22 @@ jobs:
run: |
TYPE="${{ steps.test_type.outputs.type }}"
if [ "$TYPE" = "all" ]; then
docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json
echo "tests_run=true" >> $GITHUB_ENV
docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json
elif [ "$TYPE" = "diff" ]; then
if [ -s ../../../test_dirs.txt ]; then
echo "tests_run=true" >> $GITHUB_ENV
TEST_DIRS=$(cat ../../../test_dirs.txt | xargs)
docker compose exec trinity-node-1 pytest $TEST_DIRS -v -s --ignore=tests/data --ctrf report.json
echo "tests_run=true" >> $GITHUB_ENV
else
echo "No changed modules detected, skipping tests."
echo "tests_run=false" >> $GITHUB_ENV
fi
elif [ "$TYPE" = "module" ]; then
MODULE="${{ steps.test_type.outputs.module }}"
if [ -n "$MODULE" ]; then
docker compose exec trinity-node-1 pytest tests/$MODULE -v -s --ignore=tests/data --ctrf report.json
echo "tests_run=true" >> $GITHUB_ENV
docker compose exec trinity-node-1 pytest tests/$MODULE -v -s --ignore=tests/data --ctrf report.json
else
echo "No module specified, skipping tests."
echo "tests_run=false" >> $GITHUB_ENV
129 changes: 129 additions & 0 deletions tests/algorithm/add_strategy_test.py
@@ -0,0 +1,129 @@
import unittest
from unittest.mock import AsyncMock, MagicMock

import torch

from trinity.algorithm import ADD_STRATEGY
from trinity.common.experience import EID, Experience


class TestAddStrategy(unittest.IsolatedAsyncioTestCase):
async def test_grpo_args(self):
writer = MagicMock()
writer.write_async = AsyncMock()
strategy = ADD_STRATEGY.get("grpo")(writer, epsilon=1e-7)
self.assertEqual(strategy.epsilon, 1e-7)
task_num = 3
repeat_times = 5
exps = [
Experience(
eid=EID(
batch=0,
task=j,
run=i,
),
tokens=torch.zeros(5),
prompt_length=2,
reward=i,
)
for i in range(repeat_times)
for j in range(task_num)
]
count, metrics = await strategy.add(exps, step=0)
self.assertEqual(count, task_num * repeat_times)
self.assertIn("group_advantages/reward_mean/mean", metrics)
self.assertIn("group_advantages/reward_std/mean", metrics)
self.assertTrue(metrics["group_advantages/reward_mean/mean"] == 2.0)
self.assertTrue(
metrics["group_advantages/reward_std/mean"]
== torch.std(torch.tensor([i for i in range(repeat_times)], dtype=torch.float32)).item()
)
write_async_call_count_1 = writer.write_async.call_count
self.assertEqual(write_async_call_count_1, 3)

repeat_times = 1
exps = [
Experience(
eid=EID(
batch=0,
task=j,
run=i,
),
tokens=torch.zeros(5),
prompt_length=2,
reward=i,
)
for i in range(repeat_times)
for j in range(task_num)
]
count, metrics = await strategy.add(exps, step=0)
self.assertEqual(count, task_num * repeat_times)
self.assertIn("group_advantages/reward_mean/mean", metrics)
self.assertIn("group_advantages/reward_std/mean", metrics)
self.assertTrue(metrics["group_advantages/reward_mean/mean"] == 0.0)
self.assertTrue(metrics["group_advantages/reward_std/mean"] == 1.0)
write_async_call_count_2 = writer.write_async.call_count
self.assertTrue(write_async_call_count_2 - write_async_call_count_1 == 3)

async def test_reward_variance_strategy(self):
writer = MagicMock()
writer.write_async = AsyncMock()
strategy = ADD_STRATEGY.get("reward_variance")(writer, variance_threshold=0.0)
self.assertEqual(strategy.variance_threshold, 0.0)
task_num = 3
repeat_times = 5
exps = [
Experience(
eid=EID(
batch=0,
task=j,
run=i,
),
tokens=torch.zeros(5),
prompt_length=2,
reward=0.5,
)
for i in range(repeat_times)
for j in range(task_num)
]
count, metrics = await strategy.add(exps, step=0)
self.assertEqual(count, 0)

write_async_call_count = writer.write_async.call_count
self.assertEqual(write_async_call_count, 0)

async def test_step_wise_grpo_strategy(self):
writer = MagicMock()
writer.write_async = AsyncMock()
strategy = ADD_STRATEGY.get("step_wise_grpo")(writer, epsilon=1e-7)
self.assertEqual(strategy.epsilon, 1e-7)
task_num = 2
repeat_times = 3
step_num = 4
exps = [
Experience(
eid=EID(
batch=0,
task=j,
run=i,
step=k,
),
tokens=torch.zeros(5),
prompt_length=2,
reward=i,
)
for k in range(step_num)
for i in range(repeat_times)
for j in range(task_num)
]
count, metrics = await strategy.add(exps, step=0)
self.assertEqual(count, task_num * repeat_times * step_num)
self.assertIn("group_advantages/reward_mean/mean", metrics)
self.assertIn("group_advantages/reward_std/mean", metrics)
self.assertTrue(metrics["group_advantages/reward_mean/mean"] == 1.0)
self.assertTrue(
metrics["group_advantages/reward_std/mean"]
== torch.std(torch.tensor([i for i in range(repeat_times)], dtype=torch.float32)).item()
)
write_async_call_count = writer.write_async.call_count
self.assertEqual(write_async_call_count, task_num * repeat_times)
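Note: the group statistics asserted in the first GRPO case follow directly from the synthetic rewards 0 through 4 assigned to the five runs of each task group. A minimal standalone check in plain PyTorch (independent of the Trinity code) for reference:

import torch

# Rewards of the five runs in one task group of the first GRPO test case.
rewards = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])

print(rewards.mean().item())  # 2.0 -> group_advantages/reward_mean/mean
print(rewards.std().item())   # ~1.5811 (Bessel-corrected std) -> group_advantages/reward_std/mean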
4 changes: 4 additions & 0 deletions trinity/algorithm/add_strategy/__init__.py
@@ -5,11 +5,15 @@
OPMDAddStrategy,
RewardVarianceAddStrategy,
)
from trinity.algorithm.add_strategy.step_wise_add_strategy import StepWiseGRPOStrategy

__all__ = [
"ADD_STRATEGY",
"AddStrategy",
"GRPOAddStrategy",
"OPMDAddStrategy",
"RewardVarianceAddStrategy",
"GRPOAddStrategy",
"OPMDAddStrategy",
"StepWiseGRPOStrategy",
]
2 changes: 1 addition & 1 deletion trinity/algorithm/add_strategy/add_strategy.py
@@ -178,7 +178,7 @@ def default_args(cls) -> dict:


@ADD_STRATEGY.register_module("reward_variance")
class RewardVarianceAddStrategy(GRPOAddStrategy):
class RewardVarianceAddStrategy(AddStrategy):
"""An example AddStrategy that filters experiences based on a reward variance threshold."""

def __init__(self, writer: BufferWriter, variance_threshold: float = 0.0, **kwargs) -> None:
123 changes: 123 additions & 0 deletions trinity/algorithm/add_strategy/step_wise_add_strategy.py
@@ -0,0 +1,123 @@
import asyncio
from typing import Dict, List, Tuple

import torch

from trinity.algorithm.add_strategy.add_strategy import (
ADD_STRATEGY,
AddStrategy,
group_by,
)
from trinity.buffer import BufferWriter
from trinity.common.experience import Experience
from trinity.utils.monitor import gather_metrics


@ADD_STRATEGY.register_module("step_wise_grpo")
class StepWiseGRPOStrategy(AddStrategy):
"""
An example AddStrategy that broadcasts advantages from the last step to previous steps.
Inspired by rLLM (https://github.com/rllm-org/rllm).
"""

def __init__(
self,
writer: BufferWriter,
epsilon: float = 1e-6,
enable_step_norm: bool = False,
**kwargs,
) -> None:
super().__init__(writer)
self.epsilon = epsilon
self.enable_step_norm = enable_step_norm

def calculate_group_advantage(
self, exps: Dict[str, Experience]
) -> Tuple[Dict[str, float], Dict[str, float]]:
"""Calculate group advantage for a given group of experiences.

Args:
exps (Dict[str, Experience]): One experience per run, keyed by run ID.

Returns:
Dict[str, float]: Scores for each run, keyed by run ID.
Dict[str, float]: Metrics for logging.
"""
with torch.no_grad():
if len(exps) == 1:
group_reward_mean = torch.tensor(0.0)
group_reward_std = torch.tensor(1.0)
else:
rewards = torch.tensor([exp.reward for exp in exps.values()], dtype=torch.float32)
group_reward_mean = torch.mean(rewards)
group_reward_std = torch.std(rewards)
scores = {}
for rid, exp in exps.items():
score = (exp.reward - group_reward_mean) / (group_reward_std + self.epsilon)
scores[rid] = score.item()
metrics = {
"reward_mean": group_reward_mean.item(),
"reward_std": group_reward_std.item(),
}
return scores, metrics

def broadcast_advantages(
self, run_exps: Dict[str, List[Experience]], scores: Dict[str, float]
) -> Dict[str, List[Experience]]:
"""Broadcast the calculated advantages to all previous steps in each run.

Args:
run_exps (Dict[str, List[Experience]]): Experiences grouped by run ID.
scores (Dict[str, float]): Calculated scores for each run.

Returns:
Dict[str, List[Experience]]: Updated experiences with advantages broadcasted.
"""
for run_id, exps in run_exps.items():
score = scores[run_id]
traj_length = len(exps)
for exp in exps:
exp.advantages = exp.action_mask * score # type: ignore [operator]
if self.enable_step_norm:
exp.advantages /= traj_length
exp.returns = exp.advantages.clone()
return run_exps

async def add(self, exps: List[Experience], step: int) -> Tuple[int, Dict]:
if len(exps) == 0:
return 0, {}
cnt = 0
tasks = []
metric_list = []
# Step 1: split the experiences into sub-groups by task
task_exps = group_by(exps, "task")
# Step 2: further split each task's experiences into sub-groups by run
for task_exp in task_exps.values():
run_exps = group_by(task_exp, "run")

# Step 3: extract the last experience (last step) from each run and calculate scores
last_step_exps = {run_id: step_exps[-1] for run_id, step_exps in run_exps.items()}
scores, metrics = self.calculate_group_advantage(last_step_exps)
metric_list.append(metrics)

# Step 4: broadcast the advantages to all previous steps
run_exps = self.broadcast_advantages(run_exps, scores)
for exps in run_exps.values():
cnt += len(exps)
tasks.append(self.writer.write_async(exps))

if tasks:
await asyncio.gather(*tasks)
try:
metrics = gather_metrics(metric_list, "group_advantages")
except ValueError:
metrics = {} # empty metric list causes ValueError, ignore it
return cnt, metrics

@classmethod
def default_args(cls) -> Dict:
"""Return the default configuration for this strategy."""
return {
"epsilon": 1e-6,
"enable_step_norm": False,
}
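A rough usage sketch of the new strategy, mirroring the unit test above rather than taken verbatim from the PR: the strategy is looked up from the ADD_STRATEGY registry, fed a flat list of Experience objects whose EIDs carry task/run/step indices, and it writes one group of experiences per run while returning aggregate metrics. The mock writer stands in for a real BufferWriter.

import asyncio
from unittest.mock import AsyncMock, MagicMock

import torch

from trinity.algorithm import ADD_STRATEGY
from trinity.common.experience import EID, Experience

async def main():
    writer = MagicMock()              # stand-in for a real BufferWriter
    writer.write_async = AsyncMock()  # the strategy awaits write_async once per run
    strategy = ADD_STRATEGY.get("step_wise_grpo")(writer, epsilon=1e-6)
    exps = [
        Experience(
            eid=EID(batch=0, task=task, run=run, step=step),
            tokens=torch.zeros(5),
            prompt_length=2,
            reward=float(run),
        )
        for step in range(4)   # 4 steps per run
        for run in range(3)    # 3 runs per task
        for task in range(2)   # 2 tasks
    ]
    # Advantages are computed from each run's last step and broadcast to earlier
    # steps; one write_async call is issued per run (2 tasks x 3 runs = 6 here).
    count, metrics = await strategy.add(exps, step=0)
    print(count, metrics)

asyncio.run(main())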
4 changes: 3 additions & 1 deletion trinity/buffer/ray_wrapper.py
@@ -18,7 +18,7 @@
from trinity.buffer.utils import default_storage_path, retry_session
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import ReadStrategy, StorageType
from trinity.common.experience import Experience
from trinity.common.experience import EID, Experience
from trinity.common.workflows import Task
from trinity.utils.log import get_logger

@@ -141,6 +141,8 @@ def default(self, o):
return o.to_dict()
if isinstance(o, Task):
return o.to_dict()
if isinstance(o, EID):
return o.to_dict()
return super().default(o)
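The new branch lets the buffer's JSON encoder serialize EID objects via the to_dict method added below in trinity/common/experience.py. A minimal sketch of the pattern; the encoder class name here is a hypothetical stand-in, since the real class in trinity/buffer/ray_wrapper.py is not visible in this hunk:

import json

from trinity.common.experience import EID

class _EIDAwareEncoder(json.JSONEncoder):  # hypothetical stand-in for the wrapper's encoder
    def default(self, o):
        if isinstance(o, EID):
            return o.to_dict()  # delegate to the new EID.to_dict
        return super().default(o)

# EID values nested in a payload now serialize cleanly:
print(json.dumps({"eid": EID(batch=0, task=1, run=2)}, cls=_EIDAwareEncoder))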


12 changes: 11 additions & 1 deletion trinity/common/experience.py
@@ -13,7 +13,7 @@


@dataclass
class EID(dict):
class EID:
"""Experience ID class to uniquely identify an experience.

To enable the full functionality of experience grouping, users should manually set the `run` and `step` fields in custom workflows.
@@ -71,6 +71,16 @@ def __str__(self):
def __repr__(self):
return f"EID(batch={self.batch}, task={self.task}, run={self.run}, step={self.step}, uuid={self.suffix})"

def to_dict(self) -> dict:
"""Convert the EID to a dictionary."""
return {
"batch": self.batch,
"task": self.task,
"run": self.run,
"step": self.step,
"suffix": self.suffix,
}


class ExperienceType(Enum):
"""Enum for experience types."""