From df0094e9a5baa963cdf524d77a2f0c5bb71979f6 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 15 Aug 2025 17:47:16 +0800 Subject: [PATCH 1/3] fix custom_fields in experience --- tests/common/experience_test.py | 17 ++++++++++++++++- trinity/common/experience.py | 2 +- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index 212b88fb2d..6ff3ebdbb5 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -6,7 +6,7 @@ import torch from trinity.buffer.schema.sql_schema import ExperienceModel -from trinity.common.experience import EID, Experience, Experiences +from trinity.common.experience import EID, Experience, Experiences, CustomField db_url = os.path.join(os.path.dirname(__file__), "tmp", "test.db") dataset_path = os.path.join(os.path.dirname(__file__), "data") @@ -320,6 +320,21 @@ def test_dpo_experience_batch_conversion(self): ) ) + def test_gather_experiences_with_custom_fields(self): + # test multiple experiences gathering + exps = [ + Experience(tokens=torch.tensor([1, 2]), reward=0.1, prompt_length=1, info={"a": 1.0, "b": 3}), + Experience(tokens=torch.tensor([3, 4, 5]), reward=0.2, prompt_length=2, info={"a": 2, "c": 4}), + ] + batch = Experiences.gather_experiences(exps, custom_fields=[CustomField("a", "a", torch.float32)]) + self.assertEqual(batch.batch_size, 2) + self.assertEqual(batch.prompt_length, 2) + self.assertEqual(batch.tokens.shape[1], 3) + self.assertEqual(batch.rewards[0], 0.1) + self.assertEqual(batch.rewards[1], 0.2) + self.assertIn("a", batch.custom_fields) + self.assertEqual(batch.a[0], 1) + if __name__ == "__main__": unittest.main() diff --git a/trinity/common/experience.py b/trinity/common/experience.py index c78daecdd7..169dcef443 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -351,7 +351,7 @@ def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[E return single_turn_experiences -@dataclass(frozen=True) +@dataclass class Experiences: """A container for a batch of experiences, for high performance communication usage. From b3a60b5c0ed5fcc21c263bd8e9b3a237ba322474 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 15 Aug 2025 17:48:46 +0800 Subject: [PATCH 2/3] add test --- tests/common/experience_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index 6ff3ebdbb5..bf90bef852 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -333,7 +333,8 @@ def test_gather_experiences_with_custom_fields(self): self.assertEqual(batch.rewards[0], 0.1) self.assertEqual(batch.rewards[1], 0.2) self.assertIn("a", batch.custom_fields) - self.assertEqual(batch.a[0], 1) + self.assertEqual(batch.a[0], 1.0) + self.assertEqual(batch.a[1], 2.0) if __name__ == "__main__": From a72f345d78e114cd7f396a4b07c24e518444ef11 Mon Sep 17 00:00:00 2001 From: pxc Date: Fri, 15 Aug 2025 17:49:54 +0800 Subject: [PATCH 3/3] fix pre-commit --- tests/common/experience_test.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index bf90bef852..a7ca37e699 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -6,7 +6,7 @@ import torch from trinity.buffer.schema.sql_schema import ExperienceModel -from trinity.common.experience import EID, Experience, Experiences, CustomField +from trinity.common.experience import EID, CustomField, Experience, Experiences db_url = os.path.join(os.path.dirname(__file__), "tmp", "test.db") dataset_path = os.path.join(os.path.dirname(__file__), "data") @@ -323,10 +323,16 @@ def test_dpo_experience_batch_conversion(self): def test_gather_experiences_with_custom_fields(self): # test multiple experiences gathering exps = [ - Experience(tokens=torch.tensor([1, 2]), reward=0.1, prompt_length=1, info={"a": 1.0, "b": 3}), - Experience(tokens=torch.tensor([3, 4, 5]), reward=0.2, prompt_length=2, info={"a": 2, "c": 4}), + Experience( + tokens=torch.tensor([1, 2]), reward=0.1, prompt_length=1, info={"a": 1.0, "b": 3} + ), + Experience( + tokens=torch.tensor([3, 4, 5]), reward=0.2, prompt_length=2, info={"a": 2, "c": 4} + ), ] - batch = Experiences.gather_experiences(exps, custom_fields=[CustomField("a", "a", torch.float32)]) + batch = Experiences.gather_experiences( + exps, custom_fields=[CustomField("a", "a", torch.float32)] + ) self.assertEqual(batch.batch_size, 2) self.assertEqual(batch.prompt_length, 2) self.assertEqual(batch.tokens.shape[1], 3)