From 6def24b3ab949f65e1a9aec88daf6f0190d9e748 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 30 Jul 2025 14:30:22 +0800 Subject: [PATCH 1/7] fix test reporter --- .github/workflows/unittest.yaml | 6 +++--- trinity/algorithm/add_strategy/__init__.py | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index 45b6d57e3f..41f7e226ab 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -66,13 +66,13 @@ jobs: run: | TYPE="${{ steps.test_type.outputs.type }}" if [ "$TYPE" = "all" ]; then - docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json echo "tests_run=true" >> $GITHUB_ENV + docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json elif [ "$TYPE" = "diff" ]; then if [ -s ../../../test_dirs.txt ]; then + echo "tests_run=true" >> $GITHUB_ENV TEST_DIRS=$(cat ../../../test_dirs.txt | xargs) docker compose exec trinity-node-1 pytest $TEST_DIRS -v -s --ignore=tests/data --ctrf report.json - echo "tests_run=true" >> $GITHUB_ENV else echo "No changed modules detected, skipping tests." echo "tests_run=false" >> $GITHUB_ENV @@ -80,8 +80,8 @@ jobs: elif [ "$TYPE" = "module" ]; then MODULE="${{ steps.test_type.outputs.module }}" if [ -n "$MODULE" ]; then - docker compose exec trinity-node-1 pytest tests/$MODULE -v -s --ignore=tests/data --ctrf report.json echo "tests_run=true" >> $GITHUB_ENV + docker compose exec trinity-node-1 pytest tests/$MODULE -v -s --ignore=tests/data --ctrf report.json else echo "No module specified, skipping tests." echo "tests_run=false" >> $GITHUB_ENV diff --git a/trinity/algorithm/add_strategy/__init__.py b/trinity/algorithm/add_strategy/__init__.py index d1bbc84e1c..bed728c11b 100644 --- a/trinity/algorithm/add_strategy/__init__.py +++ b/trinity/algorithm/add_strategy/__init__.py @@ -1,6 +1,8 @@ from trinity.algorithm.add_strategy.add_strategy import ( ADD_STRATEGY, AddStrategy, + GRPOAddStrategy, + OPMDAddStrategy, RewardVarianceAddStrategy, ) @@ -8,4 +10,6 @@ "ADD_STRATEGY", "AddStrategy", "RewardVarianceAddStrategy", + "GRPOAddStrategy", + "OPMDAddStrategy", ] From 014ac8dae6c735b65dbfc1cfbca8a638989197e9 Mon Sep 17 00:00:00 2001 From: pxc Date: Wed, 30 Jul 2025 20:12:12 +0800 Subject: [PATCH 2/7] add step_wise_grpo --- trinity/algorithm/add_strategy/__init__.py | 3 + .../add_strategy/step_wise_add_strategy.py | 125 ++++++++++++++++++ 2 files changed, 128 insertions(+) create mode 100644 trinity/algorithm/add_strategy/step_wise_add_strategy.py diff --git a/trinity/algorithm/add_strategy/__init__.py b/trinity/algorithm/add_strategy/__init__.py index bed728c11b..1bb547dc57 100644 --- a/trinity/algorithm/add_strategy/__init__.py +++ b/trinity/algorithm/add_strategy/__init__.py @@ -6,10 +6,13 @@ RewardVarianceAddStrategy, ) +from trinity.algorithm.add_strategy.step_wise_add_strategy import StepWiseGRPOStrategy + __all__ = [ "ADD_STRATEGY", "AddStrategy", "RewardVarianceAddStrategy", "GRPOAddStrategy", "OPMDAddStrategy", + "StepWiseGRPOStrategy", ] diff --git a/trinity/algorithm/add_strategy/step_wise_add_strategy.py b/trinity/algorithm/add_strategy/step_wise_add_strategy.py new file mode 100644 index 0000000000..6557c41a2a --- /dev/null +++ b/trinity/algorithm/add_strategy/step_wise_add_strategy.py @@ -0,0 +1,125 @@ +import asyncio +from typing import Dict, List, Tuple + +import torch + +from trinity.algorithm.add_strategy.add_strategy import ( + ADD_STRATEGY, + AddStrategy, + group_by, 
+) +from trinity.buffer import BufferWriter +from trinity.common.experience import Experience +from trinity.utils.monitor import gather_metrics + + +@ADD_STRATEGY.register_module("step_wise_grpo") +class StepWiseGRPOStrategy(AddStrategy): + """ + An example AddStrategy that broadcasts advantages from the last step to previous steps. + Inspired by rLLM (https://github.com/rllm-org/rllm). + """ + + def __init__( + self, + writer: BufferWriter, + epsilon: float = 1e-6, + enable_step_norm: bool = False, + **kwargs, + ) -> None: + super().__init__(writer) + self.epsilon = epsilon + self.enable_step_norm = enable_step_norm + + def calculate_group_advantage( + self, exps: Dict[str, Experience] + ) -> Tuple[Dict[str, float], Dict[str, float]]: + """Calculate group advantage for a given group of experiences. + + Args: + exps (Dict[str, Experience]): One experience per run, keyed by run ID. + + Returns: + Dict[str, float]: A tuple containing the scores for each run. + Dict[str, float]: Metrics for logging. + """ + print(f"Calculating group advantage for {len(exps)} experiences") + with torch.no_grad(): + if len(exps) == 1: + group_reward_mean = torch.tensor(0.0) + group_reward_std = torch.tensor(1.0) + else: + rewards = torch.tensor([exp.reward for exp in exps.values()]) + group_reward_mean = torch.mean(rewards) + group_reward_std = torch.std(rewards) + scores = {} + for rid, exp in exps.items(): + score = (exp.reward - group_reward_mean) / (group_reward_std + self.epsilon) + scores[rid] = score + metrics = { + "group_reward_mean": group_reward_mean.item(), + "group_reward_std": group_reward_std.item(), + } + return scores, metrics + + def broadcast_advantages( + self, run_exps: Dict[str, List[Experience]], scores: Dict[str, float] + ) -> Dict[str, List[Experience]]: + """Broadcast the calculated advantages to all previous steps in each run. + + Args: + run_exps (Dict[str, List[Experience]]): Experiences grouped by run ID. + scores (Dict[str, float]): Calculated scores for each run. + + Returns: + Dict[str, List[Experience]]: Updated experiences with advantages broadcasted. 
+ """ + for run_id, exps in run_exps.items(): + score = scores[run_id] + traj_length = len(exps) + for exp in exps: + exp.advantages = exp.action_mask * score # type: ignore [operator] + if self.enable_step_norm: + exp.advantages /= traj_length + exp.returns = exp.advantages.clone() + return run_exps + + async def add(self, exps: List[Experience], step: int) -> Tuple[int, Dict]: + if len(exps) == 0: + return 0, {} + print(f"adding {len(exps)} experiences") + cnt = 0 + tasks = [] + metric_list = [] + # Step 1: split the experiences into sub-groups by task + task_exps = group_by(exps, "task") + # Step 2: further split each task's experiences into sub-groups by run + run_exps = {} + for task_exp in task_exps.values(): + run_exps = group_by(task_exp, "run") + + # Step3: extract the last experience (last step) from each run and calculate scores + last_step_exps = {run_id: step_exps[-1] for run_id, step_exps in run_exps.items()} + scores, metrics = self.calculate_group_advantage(last_step_exps) + metric_list.append(metrics) + + # Step 4: broadcast the advantages to all previous steps + run_exps = self.broadcast_advantages(run_exps, scores) + for exps in run_exps.values(): + cnt += len(exps) + tasks.append(self.writer.write_async(exps)) + + if tasks: + await asyncio.gather(*tasks) + try: + metrics = gather_metrics(metric_list, "group_advantages") + except ValueError: + metrics = {} # empty metric list causes ValueError, ignore it + return cnt, metrics + + def default_args(self) -> Dict: + """Return the default configuration for this strategy.""" + return { + "epsilon": 1e-6, + "enable_step_norm": False, + } \ No newline at end of file From 3d3198ecec724d2c9532e9f45a3dc149762b05ea Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 31 Jul 2025 13:51:07 +0800 Subject: [PATCH 3/7] fix comments --- trinity/algorithm/add_strategy/__init__.py | 1 - .../algorithm/add_strategy/step_wise_add_strategy.py | 10 ++++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/trinity/algorithm/add_strategy/__init__.py b/trinity/algorithm/add_strategy/__init__.py index 1bb547dc57..cb06982d24 100644 --- a/trinity/algorithm/add_strategy/__init__.py +++ b/trinity/algorithm/add_strategy/__init__.py @@ -5,7 +5,6 @@ OPMDAddStrategy, RewardVarianceAddStrategy, ) - from trinity.algorithm.add_strategy.step_wise_add_strategy import StepWiseGRPOStrategy __all__ = [ diff --git a/trinity/algorithm/add_strategy/step_wise_add_strategy.py b/trinity/algorithm/add_strategy/step_wise_add_strategy.py index 6557c41a2a..25af30d1b9 100644 --- a/trinity/algorithm/add_strategy/step_wise_add_strategy.py +++ b/trinity/algorithm/add_strategy/step_wise_add_strategy.py @@ -43,7 +43,6 @@ def calculate_group_advantage( Dict[str, float]: A tuple containing the scores for each run. Dict[str, float]: Metrics for logging. 
""" - print(f"Calculating group advantage for {len(exps)} experiences") with torch.no_grad(): if len(exps) == 1: group_reward_mean = torch.tensor(0.0) @@ -55,7 +54,7 @@ def calculate_group_advantage( scores = {} for rid, exp in exps.items(): score = (exp.reward - group_reward_mean) / (group_reward_std + self.epsilon) - scores[rid] = score + scores[rid] = score.item() metrics = { "group_reward_mean": group_reward_mean.item(), "group_reward_std": group_reward_std.item(), @@ -87,14 +86,12 @@ def broadcast_advantages( async def add(self, exps: List[Experience], step: int) -> Tuple[int, Dict]: if len(exps) == 0: return 0, {} - print(f"adding {len(exps)} experiences") cnt = 0 tasks = [] metric_list = [] # Step 1: split the experiences into sub-groups by task task_exps = group_by(exps, "task") # Step 2: further split each task's experiences into sub-groups by run - run_exps = {} for task_exp in task_exps.values(): run_exps = group_by(task_exp, "run") @@ -117,9 +114,10 @@ async def add(self, exps: List[Experience], step: int) -> Tuple[int, Dict]: metrics = {} # empty metric list causes ValueError, ignore it return cnt, metrics - def default_args(self) -> Dict: + @classmethod + def default_args(cls) -> Dict: """Return the default configuration for this strategy.""" return { "epsilon": 1e-6, "enable_step_norm": False, - } \ No newline at end of file + } From c8cd1c591afd0b2a11418f27102d873d41e9d091 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 31 Jul 2025 19:54:19 +0800 Subject: [PATCH 4/7] fix torch dtype --- trinity/algorithm/add_strategy/add_strategy.py | 4 ++-- trinity/algorithm/add_strategy/step_wise_add_strategy.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/trinity/algorithm/add_strategy/add_strategy.py b/trinity/algorithm/add_strategy/add_strategy.py index cdff571ca0..c9083ea2ec 100644 --- a/trinity/algorithm/add_strategy/add_strategy.py +++ b/trinity/algorithm/add_strategy/add_strategy.py @@ -110,7 +110,7 @@ def calculate_group_advantage( group_reward_mean = torch.tensor(0.0) group_reward_std = torch.tensor(1.0) else: - rewards = torch.tensor([exp.reward for exp in exps]) + rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32) group_reward_mean = torch.mean(rewards) group_reward_std = torch.std(rewards) for exp in exps: @@ -155,7 +155,7 @@ def calculate_group_advantage( if len(exps) == 1: group_baseline = torch.tensor(0.0) else: - group_rewards = torch.tensor([exp.reward for exp in exps]) + group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32) if self.opmd_baseline == "mean": group_baseline = torch.mean(group_rewards) else: diff --git a/trinity/algorithm/add_strategy/step_wise_add_strategy.py b/trinity/algorithm/add_strategy/step_wise_add_strategy.py index 25af30d1b9..17dc6d8c12 100644 --- a/trinity/algorithm/add_strategy/step_wise_add_strategy.py +++ b/trinity/algorithm/add_strategy/step_wise_add_strategy.py @@ -48,7 +48,7 @@ def calculate_group_advantage( group_reward_mean = torch.tensor(0.0) group_reward_std = torch.tensor(1.0) else: - rewards = torch.tensor([exp.reward for exp in exps.values()]) + rewards = torch.tensor([exp.reward for exp in exps.values()], dtype=torch.float32) group_reward_mean = torch.mean(rewards) group_reward_std = torch.std(rewards) scores = {} From fe7d2121e61342cac174a528fa9a383410d1dbd2 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 31 Jul 2025 20:35:40 +0800 Subject: [PATCH 5/7] add tests for add_strategy --- tests/algorithm/add_strategy_test.py | 62 ++++++++++++++++++++++++++++ 
trinity/buffer/ray_wrapper.py | 4 +- trinity/common/experience.py | 12 +++++- 3 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 tests/algorithm/add_strategy_test.py diff --git a/tests/algorithm/add_strategy_test.py b/tests/algorithm/add_strategy_test.py new file mode 100644 index 0000000000..2812320256 --- /dev/null +++ b/tests/algorithm/add_strategy_test.py @@ -0,0 +1,62 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock + +import torch + +from trinity.algorithm import ADD_STRATEGY +from trinity.common.experience import EID, Experience + + +class TestAddStrategy(unittest.IsolatedAsyncioTestCase): + async def test_grpo_args(self): + writer = MagicMock() + writer.write_async = AsyncMock() + strategy = ADD_STRATEGY.get("grpo")(writer, epsilon=1e-7) + self.assertEqual(strategy.epsilon, 1e-7) + task_num = 3 + repeat_times = 5 + exps = [ + Experience( + eid=EID( + batch=0, + task=j, + run=i, + ), + tokens=torch.zeros(5), + prompt_length=2, + reward=i, + ) + for i in range(repeat_times) + for j in range(task_num) + ] + count, metrics = await strategy.add(exps, step=0) + self.assertEqual(count, task_num * repeat_times) + self.assertIn("group_advantages/reward_mean/mean", metrics) + self.assertIn("group_advantages/reward_std/mean", metrics) + self.assertTrue(metrics["group_advantages/reward_mean/mean"] == 2.0) + self.assertTrue( + metrics["group_advantages/reward_std/mean"] + == torch.std(torch.tensor([i for i in range(repeat_times)], dtype=torch.float32)).item() + ) + + repeat_times = 1 + exps = [ + Experience( + eid=EID( + batch=0, + task=j, + run=i, + ), + tokens=torch.zeros(5), + prompt_length=2, + reward=i, + ) + for i in range(repeat_times) + for j in range(task_num) + ] + count, metrics = await strategy.add(exps, step=0) + self.assertEqual(count, task_num * repeat_times) + self.assertIn("group_advantages/reward_mean/mean", metrics) + self.assertIn("group_advantages/reward_std/mean", metrics) + self.assertTrue(metrics["group_advantages/reward_mean/mean"] == 0.0) + self.assertTrue(metrics["group_advantages/reward_std/mean"] == 1.0) diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py index 7fe4fb0bf1..4222a023b1 100644 --- a/trinity/buffer/ray_wrapper.py +++ b/trinity/buffer/ray_wrapper.py @@ -18,7 +18,7 @@ from trinity.buffer.utils import default_storage_path, retry_session from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import ReadStrategy, StorageType -from trinity.common.experience import Experience +from trinity.common.experience import EID, Experience from trinity.common.workflows import Task from trinity.utils.log import get_logger @@ -141,6 +141,8 @@ def default(self, o): return o.to_dict() if isinstance(o, Task): return o.to_dict() + if isinstance(o, EID): + return o.to_dict() return super().default(o) diff --git a/trinity/common/experience.py b/trinity/common/experience.py index ddea73bba7..c78daecdd7 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -13,7 +13,7 @@ @dataclass -class EID(dict): +class EID: """Experience ID class to uniquely identify an experience. To enable the full functionality of the experience grouping, user should manually set the `run` and `step` fields in custom workflows. 
@@ -71,6 +71,16 @@ def __str__(self): def __repr__(self): return f"EID(batch={self.batch}, task={self.task}, run={self.run}, step={self.step}, uuid={self.suffix})" + def to_dict(self) -> dict: + """Convert the EID to a dictionary.""" + return { + "batch": self.batch, + "task": self.task, + "run": self.run, + "step": self.step, + "suffix": self.suffix, + } + class ExperienceType(Enum): """Enum for experience types.""" From 290ac3e8ed26c5b1e7fbf06fc821037488810729 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 31 Jul 2025 20:46:16 +0800 Subject: [PATCH 6/7] add more tests --- tests/algorithm/add_strategy_test.py | 68 +++++++++++++++++++ .../algorithm/add_strategy/add_strategy.py | 2 +- .../add_strategy/step_wise_add_strategy.py | 4 +- 3 files changed, 71 insertions(+), 3 deletions(-) diff --git a/tests/algorithm/add_strategy_test.py b/tests/algorithm/add_strategy_test.py index 2812320256..6b080c0dc8 100644 --- a/tests/algorithm/add_strategy_test.py +++ b/tests/algorithm/add_strategy_test.py @@ -8,6 +8,7 @@ class TestAddStrategy(unittest.IsolatedAsyncioTestCase): + async def test_grpo_args(self): writer = MagicMock() writer.write_async = AsyncMock() @@ -38,6 +39,8 @@ async def test_grpo_args(self): metrics["group_advantages/reward_std/mean"] == torch.std(torch.tensor([i for i in range(repeat_times)], dtype=torch.float32)).item() ) + write_async_call_count_1 = writer.write_async.call_count + self.assertEqual(write_async_call_count_1, 3) repeat_times = 1 exps = [ @@ -60,3 +63,68 @@ async def test_grpo_args(self): self.assertIn("group_advantages/reward_std/mean", metrics) self.assertTrue(metrics["group_advantages/reward_mean/mean"] == 0.0) self.assertTrue(metrics["group_advantages/reward_std/mean"] == 1.0) + write_async_call_count_2 = writer.write_async.call_count + self.assertTrue(write_async_call_count_2 - write_async_call_count_1 == 3) + + + async def test_reward_variance_strategy(self): + writer = MagicMock() + writer.write_async = AsyncMock() + strategy = ADD_STRATEGY.get("reward_variance")(writer, variance_threshold=0.0) + self.assertEqual(strategy.variance_threshold, 0.0) + task_num = 3 + repeat_times = 5 + exps = [ + Experience( + eid=EID( + batch=0, + task=j, + run=i, + ), + tokens=torch.zeros(5), + prompt_length=2, + reward=0.5, + ) + for i in range(repeat_times) + for j in range(task_num) + ] + count, metrics = await strategy.add(exps, step=0) + self.assertEqual(count, 0) + + write_async_call_count = writer.write_async.call_count + self.assertEqual(write_async_call_count, 0) + + + async def test_step_wise_grpo_strategy(self): + writer = MagicMock() + writer.write_async = AsyncMock() + strategy = ADD_STRATEGY.get("step_wise_grpo")(writer, epsilon=1e-7) + self.assertEqual(strategy.epsilon, 1e-7) + task_num = 2 + repeat_times = 3 + step_num = 4 + exps = [ + Experience( + eid=EID( + batch=0, + task=j, + run=i, + step=k, + ), + tokens=torch.zeros(5), + prompt_length=2, + reward=i, + ) + for k in range(step_num) + for i in range(repeat_times) + for j in range(task_num) + ] + count, metrics = await strategy.add(exps, step=0) + self.assertEqual(count, task_num * repeat_times * step_num) + self.assertIn("group_advantages/reward_mean/mean", metrics) + self.assertIn("group_advantages/reward_std/mean", metrics) + self.assertTrue(metrics["group_advantages/reward_mean/mean"] == 1.0) + self.assertTrue(metrics["group_advantages/reward_std/mean"] == torch.std(torch.tensor([i for i in range(repeat_times)], dtype=torch.float32)).item() + ) + write_async_call_count = writer.write_async.call_count + 
self.assertEqual(write_async_call_count, task_num * repeat_times) \ No newline at end of file diff --git a/trinity/algorithm/add_strategy/add_strategy.py b/trinity/algorithm/add_strategy/add_strategy.py index c9083ea2ec..3984d49759 100644 --- a/trinity/algorithm/add_strategy/add_strategy.py +++ b/trinity/algorithm/add_strategy/add_strategy.py @@ -178,7 +178,7 @@ def default_args(cls) -> dict: @ADD_STRATEGY.register_module("reward_variance") -class RewardVarianceAddStrategy(GRPOAddStrategy): +class RewardVarianceAddStrategy(AddStrategy): """An example AddStrategy that filters experiences based on a reward variance threshold.""" def __init__(self, writer: BufferWriter, variance_threshold: float = 0.0, **kwargs) -> None: diff --git a/trinity/algorithm/add_strategy/step_wise_add_strategy.py b/trinity/algorithm/add_strategy/step_wise_add_strategy.py index 17dc6d8c12..b7ba2a514d 100644 --- a/trinity/algorithm/add_strategy/step_wise_add_strategy.py +++ b/trinity/algorithm/add_strategy/step_wise_add_strategy.py @@ -56,8 +56,8 @@ def calculate_group_advantage( score = (exp.reward - group_reward_mean) / (group_reward_std + self.epsilon) scores[rid] = score.item() metrics = { - "group_reward_mean": group_reward_mean.item(), - "group_reward_std": group_reward_std.item(), + "reward_mean": group_reward_mean.item(), + "reward_std": group_reward_std.item(), } return scores, metrics From c09f92c8d3af791177775367d73327f7876dca65 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 31 Jul 2025 20:47:40 +0800 Subject: [PATCH 7/7] fix pre-commit --- tests/algorithm/add_strategy_test.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/algorithm/add_strategy_test.py b/tests/algorithm/add_strategy_test.py index 6b080c0dc8..171cd4e89f 100644 --- a/tests/algorithm/add_strategy_test.py +++ b/tests/algorithm/add_strategy_test.py @@ -8,7 +8,6 @@ class TestAddStrategy(unittest.IsolatedAsyncioTestCase): - async def test_grpo_args(self): writer = MagicMock() writer.write_async = AsyncMock() @@ -66,7 +65,6 @@ async def test_grpo_args(self): write_async_call_count_2 = writer.write_async.call_count self.assertTrue(write_async_call_count_2 - write_async_call_count_1 == 3) - async def test_reward_variance_strategy(self): writer = MagicMock() writer.write_async = AsyncMock() @@ -94,7 +92,6 @@ async def test_reward_variance_strategy(self): write_async_call_count = writer.write_async.call_count self.assertEqual(write_async_call_count, 0) - async def test_step_wise_grpo_strategy(self): writer = MagicMock() writer.write_async = AsyncMock() @@ -124,7 +121,9 @@ async def test_step_wise_grpo_strategy(self): self.assertIn("group_advantages/reward_mean/mean", metrics) self.assertIn("group_advantages/reward_std/mean", metrics) self.assertTrue(metrics["group_advantages/reward_mean/mean"] == 1.0) - self.assertTrue(metrics["group_advantages/reward_std/mean"] == torch.std(torch.tensor([i for i in range(repeat_times)], dtype=torch.float32)).item() + self.assertTrue( + metrics["group_advantages/reward_std/mean"] + == torch.std(torch.tensor([i for i in range(repeat_times)], dtype=torch.float32)).item() ) write_async_call_count = writer.write_async.call_count - self.assertEqual(write_async_call_count, task_num * repeat_times) \ No newline at end of file + self.assertEqual(write_async_call_count, task_num * repeat_times)
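For reference, a minimal usage sketch of the `step_wise_grpo` strategy introduced in this series, following the same mock-writer pattern as tests/algorithm/add_strategy_test.py above. The registry key, constructor arguments, and the `Experience`/`EID` fields are taken from the diffs in this series; the script entry point and the concrete reward values are illustrative only.

import asyncio
from unittest.mock import AsyncMock, MagicMock

import torch

from trinity.algorithm import ADD_STRATEGY
from trinity.common.experience import EID, Experience


async def main():
    # Mocked buffer writer, as in the unit tests above; in real use this is a BufferWriter.
    writer = MagicMock()
    writer.write_async = AsyncMock()

    # Build the strategy through the registry key added in this series.
    strategy = ADD_STRATEGY.get("step_wise_grpo")(writer, epsilon=1e-6, enable_step_norm=True)

    # Two tasks, three runs per task, four steps per run; only the last step's reward
    # is used for the group advantage, which is then broadcast to the earlier steps.
    exps = [
        Experience(
            eid=EID(batch=0, task=task, run=run, step=step),
            tokens=torch.zeros(5),
            prompt_length=2,
            reward=float(run),
        )
        for task in range(2)
        for run in range(3)
        for step in range(4)
    ]

    count, metrics = await strategy.add(exps, step=0)
    # count == 24 here; writer.write_async is called once per run (6 times), and
    # metrics carry the "group_advantages/reward_mean" / "reward_std" aggregates.
    print(count, metrics)


if __name__ == "__main__":
    asyncio.run(main())

With enable_step_norm=True, each step's advantage is additionally divided by the trajectory length, as implemented in broadcast_advantages above.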