diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml
index 45b6d57e3f..41f7e226ab 100644
--- a/.github/workflows/unittest.yaml
+++ b/.github/workflows/unittest.yaml
@@ -66,13 +66,13 @@ jobs:
         run: |
           TYPE="${{ steps.test_type.outputs.type }}"
           if [ "$TYPE" = "all" ]; then
-            docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json
             echo "tests_run=true" >> $GITHUB_ENV
+            docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json
           elif [ "$TYPE" = "diff" ]; then
             if [ -s ../../../test_dirs.txt ]; then
+              echo "tests_run=true" >> $GITHUB_ENV
               TEST_DIRS=$(cat ../../../test_dirs.txt | xargs)
               docker compose exec trinity-node-1 pytest $TEST_DIRS -v -s --ignore=tests/data --ctrf report.json
-              echo "tests_run=true" >> $GITHUB_ENV
             else
               echo "No changed modules detected, skipping tests."
               echo "tests_run=false" >> $GITHUB_ENV
@@ -80,8 +80,8 @@ jobs:
           elif [ "$TYPE" = "module" ]; then
             MODULE="${{ steps.test_type.outputs.module }}"
             if [ -n "$MODULE" ]; then
-              docker compose exec trinity-node-1 pytest tests/$MODULE -v -s --ignore=tests/data --ctrf report.json
               echo "tests_run=true" >> $GITHUB_ENV
+              docker compose exec trinity-node-1 pytest tests/$MODULE -v -s --ignore=tests/data --ctrf report.json
             else
               echo "No module specified, skipping tests."
               echo "tests_run=false" >> $GITHUB_ENV
diff --git a/tests/algorithm/add_strategy_test.py b/tests/algorithm/add_strategy_test.py
new file mode 100644
index 0000000000..171cd4e89f
--- /dev/null
+++ b/tests/algorithm/add_strategy_test.py
@@ -0,0 +1,129 @@
+import unittest
+from unittest.mock import AsyncMock, MagicMock
+
+import torch
+
+from trinity.algorithm import ADD_STRATEGY
+from trinity.common.experience import EID, Experience
+
+
+class TestAddStrategy(unittest.IsolatedAsyncioTestCase):
+    async def test_grpo_args(self):
+        writer = MagicMock()
+        writer.write_async = AsyncMock()
+        strategy = ADD_STRATEGY.get("grpo")(writer, epsilon=1e-7)
+        self.assertEqual(strategy.epsilon, 1e-7)
+        task_num = 3
+        repeat_times = 5
+        exps = [
+            Experience(
+                eid=EID(
+                    batch=0,
+                    task=j,
+                    run=i,
+                ),
+                tokens=torch.zeros(5),
+                prompt_length=2,
+                reward=i,
+            )
+            for i in range(repeat_times)
+            for j in range(task_num)
+        ]
+        count, metrics = await strategy.add(exps, step=0)
+        self.assertEqual(count, task_num * repeat_times)
+        self.assertIn("group_advantages/reward_mean/mean", metrics)
+        self.assertIn("group_advantages/reward_std/mean", metrics)
+        self.assertTrue(metrics["group_advantages/reward_mean/mean"] == 2.0)
+        self.assertTrue(
+            metrics["group_advantages/reward_std/mean"]
+            == torch.std(torch.tensor([i for i in range(repeat_times)], dtype=torch.float32)).item()
+        )
+        write_async_call_count_1 = writer.write_async.call_count
+        self.assertEqual(write_async_call_count_1, 3)
+
+        repeat_times = 1
+        exps = [
+            Experience(
+                eid=EID(
+                    batch=0,
+                    task=j,
+                    run=i,
+                ),
+                tokens=torch.zeros(5),
+                prompt_length=2,
+                reward=i,
+            )
+            for i in range(repeat_times)
+            for j in range(task_num)
+        ]
+        count, metrics = await strategy.add(exps, step=0)
+        self.assertEqual(count, task_num * repeat_times)
+        self.assertIn("group_advantages/reward_mean/mean", metrics)
+        self.assertIn("group_advantages/reward_std/mean", metrics)
+        self.assertTrue(metrics["group_advantages/reward_mean/mean"] == 0.0)
+        self.assertTrue(metrics["group_advantages/reward_std/mean"] == 1.0)
+        write_async_call_count_2 = writer.write_async.call_count
+        self.assertTrue(write_async_call_count_2 - write_async_call_count_1 == 3)
+
+    async def test_reward_variance_strategy(self):
+        writer = MagicMock()
+        writer.write_async = AsyncMock()
+        strategy = ADD_STRATEGY.get("reward_variance")(writer, variance_threshold=0.0)
+        self.assertEqual(strategy.variance_threshold, 0.0)
+        task_num = 3
+        repeat_times = 5
+        exps = [
+            Experience(
+                eid=EID(
+                    batch=0,
+                    task=j,
+                    run=i,
+                ),
+                tokens=torch.zeros(5),
+                prompt_length=2,
+                reward=0.5,
+            )
+            for i in range(repeat_times)
+            for j in range(task_num)
+        ]
+        count, metrics = await strategy.add(exps, step=0)
+        self.assertEqual(count, 0)
+
+        write_async_call_count = writer.write_async.call_count
+        self.assertEqual(write_async_call_count, 0)
+
+    async def test_step_wise_grpo_strategy(self):
+        writer = MagicMock()
+        writer.write_async = AsyncMock()
+        strategy = ADD_STRATEGY.get("step_wise_grpo")(writer, epsilon=1e-7)
+        self.assertEqual(strategy.epsilon, 1e-7)
+        task_num = 2
+        repeat_times = 3
+        step_num = 4
+        exps = [
+            Experience(
+                eid=EID(
+                    batch=0,
+                    task=j,
+                    run=i,
+                    step=k,
+                ),
+                tokens=torch.zeros(5),
+                prompt_length=2,
+                reward=i,
+            )
+            for k in range(step_num)
+            for i in range(repeat_times)
+            for j in range(task_num)
+        ]
+        count, metrics = await strategy.add(exps, step=0)
+        self.assertEqual(count, task_num * repeat_times * step_num)
+        self.assertIn("group_advantages/reward_mean/mean", metrics)
+        self.assertIn("group_advantages/reward_std/mean", metrics)
+        self.assertTrue(metrics["group_advantages/reward_mean/mean"] == 1.0)
+        self.assertTrue(
+            metrics["group_advantages/reward_std/mean"]
+            == torch.std(torch.tensor([i for i in range(repeat_times)], dtype=torch.float32)).item()
+        )
+        write_async_call_count = writer.write_async.call_count
+        self.assertEqual(write_async_call_count, task_num * repeat_times)
diff --git a/trinity/algorithm/add_strategy/__init__.py b/trinity/algorithm/add_strategy/__init__.py
index 51df3ed540..2ac7d8e569 100644
--- a/trinity/algorithm/add_strategy/__init__.py
+++ b/trinity/algorithm/add_strategy/__init__.py
@@ -5,6 +5,7 @@
     OPMDAddStrategy,
     RewardVarianceAddStrategy,
 )
+from trinity.algorithm.add_strategy.step_wise_add_strategy import StepWiseGRPOStrategy
 
 __all__ = [
     "ADD_STRATEGY",
@@ -12,4 +13,5 @@
     "GRPOAddStrategy",
     "OPMDAddStrategy",
     "RewardVarianceAddStrategy",
+    "StepWiseGRPOStrategy",
 ]
diff --git a/trinity/algorithm/add_strategy/add_strategy.py b/trinity/algorithm/add_strategy/add_strategy.py
index c9083ea2ec..3984d49759 100644
--- a/trinity/algorithm/add_strategy/add_strategy.py
+++ b/trinity/algorithm/add_strategy/add_strategy.py
@@ -178,7 +178,7 @@ def default_args(cls) -> dict:
 
 
 @ADD_STRATEGY.register_module("reward_variance")
-class RewardVarianceAddStrategy(GRPOAddStrategy):
+class RewardVarianceAddStrategy(AddStrategy):
     """An example AddStrategy that filters experiences based on a reward variance threshold."""
 
     def __init__(self, writer: BufferWriter, variance_threshold: float = 0.0, **kwargs) -> None:
diff --git a/trinity/algorithm/add_strategy/step_wise_add_strategy.py b/trinity/algorithm/add_strategy/step_wise_add_strategy.py
new file mode 100644
index 0000000000..b7ba2a514d
--- /dev/null
+++ b/trinity/algorithm/add_strategy/step_wise_add_strategy.py
@@ -0,0 +1,123 @@
+import asyncio
+from typing import Dict, List, Tuple
+
+import torch
+
+from trinity.algorithm.add_strategy.add_strategy import (
+    ADD_STRATEGY,
+    AddStrategy,
+    group_by,
+)
+from trinity.buffer import BufferWriter
+from trinity.common.experience import Experience
+from trinity.utils.monitor import gather_metrics
+
+
+@ADD_STRATEGY.register_module("step_wise_grpo")
+class StepWiseGRPOStrategy(AddStrategy):
+    """
+    An example AddStrategy that broadcasts advantages from the last step to previous steps.
+    Inspired by rLLM (https://github.com/rllm-org/rllm).
+    """
+
+    def __init__(
+        self,
+        writer: BufferWriter,
+        epsilon: float = 1e-6,
+        enable_step_norm: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(writer)
+        self.epsilon = epsilon
+        self.enable_step_norm = enable_step_norm
+
+    def calculate_group_advantage(
+        self, exps: Dict[str, Experience]
+    ) -> Tuple[Dict[str, float], Dict[str, float]]:
+        """Calculate group advantage for a given group of experiences.
+
+        Args:
+            exps (Dict[str, Experience]): One experience per run, keyed by run ID.
+
+        Returns:
+            Dict[str, float]: Scores for each run, keyed by run ID.
+            Dict[str, float]: Metrics for logging.
+        """
+        with torch.no_grad():
+            if len(exps) == 1:
+                group_reward_mean = torch.tensor(0.0)
+                group_reward_std = torch.tensor(1.0)
+            else:
+                rewards = torch.tensor([exp.reward for exp in exps.values()], dtype=torch.float32)
+                group_reward_mean = torch.mean(rewards)
+                group_reward_std = torch.std(rewards)
+            scores = {}
+            for rid, exp in exps.items():
+                score = (exp.reward - group_reward_mean) / (group_reward_std + self.epsilon)
+                scores[rid] = score.item()
+            metrics = {
+                "reward_mean": group_reward_mean.item(),
+                "reward_std": group_reward_std.item(),
+            }
+            return scores, metrics
+
+    def broadcast_advantages(
+        self, run_exps: Dict[str, List[Experience]], scores: Dict[str, float]
+    ) -> Dict[str, List[Experience]]:
+        """Broadcast the calculated advantages to all previous steps in each run.
+
+        Args:
+            run_exps (Dict[str, List[Experience]]): Experiences grouped by run ID.
+            scores (Dict[str, float]): Calculated scores for each run.
+
+        Returns:
+            Dict[str, List[Experience]]: Updated experiences with advantages broadcast to all steps.
+ """ + for run_id, exps in run_exps.items(): + score = scores[run_id] + traj_length = len(exps) + for exp in exps: + exp.advantages = exp.action_mask * score # type: ignore [operator] + if self.enable_step_norm: + exp.advantages /= traj_length + exp.returns = exp.advantages.clone() + return run_exps + + async def add(self, exps: List[Experience], step: int) -> Tuple[int, Dict]: + if len(exps) == 0: + return 0, {} + cnt = 0 + tasks = [] + metric_list = [] + # Step 1: split the experiences into sub-groups by task + task_exps = group_by(exps, "task") + # Step 2: further split each task's experiences into sub-groups by run + for task_exp in task_exps.values(): + run_exps = group_by(task_exp, "run") + + # Step3: extract the last experience (last step) from each run and calculate scores + last_step_exps = {run_id: step_exps[-1] for run_id, step_exps in run_exps.items()} + scores, metrics = self.calculate_group_advantage(last_step_exps) + metric_list.append(metrics) + + # Step 4: broadcast the advantages to all previous steps + run_exps = self.broadcast_advantages(run_exps, scores) + for exps in run_exps.values(): + cnt += len(exps) + tasks.append(self.writer.write_async(exps)) + + if tasks: + await asyncio.gather(*tasks) + try: + metrics = gather_metrics(metric_list, "group_advantages") + except ValueError: + metrics = {} # empty metric list causes ValueError, ignore it + return cnt, metrics + + @classmethod + def default_args(cls) -> Dict: + """Return the default configuration for this strategy.""" + return { + "epsilon": 1e-6, + "enable_step_norm": False, + } diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py index 7fe4fb0bf1..4222a023b1 100644 --- a/trinity/buffer/ray_wrapper.py +++ b/trinity/buffer/ray_wrapper.py @@ -18,7 +18,7 @@ from trinity.buffer.utils import default_storage_path, retry_session from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import ReadStrategy, StorageType -from trinity.common.experience import Experience +from trinity.common.experience import EID, Experience from trinity.common.workflows import Task from trinity.utils.log import get_logger @@ -141,6 +141,8 @@ def default(self, o): return o.to_dict() if isinstance(o, Task): return o.to_dict() + if isinstance(o, EID): + return o.to_dict() return super().default(o) diff --git a/trinity/common/experience.py b/trinity/common/experience.py index ddea73bba7..c78daecdd7 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -13,7 +13,7 @@ @dataclass -class EID(dict): +class EID: """Experience ID class to uniquely identify an experience. To enable the full functionality of the experience grouping, user should manually set the `run` and `step` fields in custom workflows. @@ -71,6 +71,16 @@ def __str__(self): def __repr__(self): return f"EID(batch={self.batch}, task={self.task}, run={self.run}, step={self.step}, uuid={self.suffix})" + def to_dict(self) -> dict: + """Convert the EID to a dictionary.""" + return { + "batch": self.batch, + "task": self.task, + "run": self.run, + "step": self.step, + "suffix": self.suffix, + } + class ExperienceType(Enum): """Enum for experience types."""