diff --git a/README.md b/README.md
index 19fbcc46fa..ce504e4cf5 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,7 @@ Trinity-RFT provides functionalities for users with different backgrounds and ob
 
 ## 🚀 News
 
+* [2026-01] 🎉 Three papers accepted by ICLR 2026: [CHORD](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/mix_chord), [BOTS](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/bots), and [Group-relative REINFORCE variants](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/rec_gsm8k). Try out these new algorithms in Trinity-RFT!
 * [2026-01] [[Release Notes]](https://github.com/agentscope-ai/Trinity-RFT/releases/tag/v0.4.1) Trinity-RFT v0.4.1 released: upgraded verl to v0.7.0, Tinker backend supports OpenAI API, bug fixes.
 * [2026-01] Introducing [R3L](https://github.com/shiweijiezero/R3L): a systematic reflect-then-retry RL mechanism with efficient language-guided exploration and stable off-policy learning ([paper](https://arxiv.org/abs/2601.03715)).
 * [2025-12] [[Release Notes]](https://github.com/agentscope-ai/Trinity-RFT/releases/tag/v0.4.0) Trinity-RFT v0.4.0 released: added [Tinker](https://thinkingmachines.ai/tinker/) backend for users **without GPUs**, add more benchmarks, enhance online RL and more.
diff --git a/README_zh.md b/README_zh.md
index 9348f6c907..128c9a8187 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -41,6 +41,7 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能:
 
 ## 🚀 新闻
 
+* [2026-01] 🎉 三篇论文被 ICLR 2026 接收:[CHORD](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/mix_chord)、[BOTS](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/bots) 和 [Group-relative REINFORCE 系列变种](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/rec_gsm8k)。在 Trinity-RFT 中尝试这些新算法吧!
 * [2026-01] [[发布说明]](https://github.com/agentscope-ai/Trinity-RFT/releases/tag/v0.4.1) Trinity-RFT v0.4.1 发布:升级 verl 至 v0.7.0,Tinker 后端支持 OpenAI API,修复若干 Bug。
 * [2026-01] 推出 [R3L](https://github.com/shiweijiezero/R3L):基于反思-重试的强化学习机制,由自然语言反馈引导高效探索,并达成稳定的 off-policy 学习([论文](https://arxiv.org/abs/2601.03715))。
 * [2025-12] [[发布说明]](https://github.com/agentscope-ai/Trinity-RFT/releases/tag/v0.4.0) Trinity-RFT v0.4.0 发布:新增[Tinker](https://thinkingmachines.ai/tinker/) 后端以支持在 **无 GPU** 的设备上训练,增加更多基准测试,增强在线 RL 等功能。
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index 9ffe3e7c3c..d238653946 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -353,6 +353,80 @@ async def test_model_len(self):
         )
 
 
+class TestMessageProcess(RayUnittestBaseAsync):
+    def setUp(self):
+        self.config = get_template_config()
+        self.config.mode = "explore"
+        self.config.model.model_path = get_model_path()
+        self.config.model.max_model_len = 100
+        self.config.model.max_prompt_tokens = 50
+        self.config.model.max_response_tokens = 50
+        self.config.model.enable_prompt_truncation = True
+        self.config.explorer.rollout_model.enable_openai_api = True
+        self.config.check_and_update()
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
+        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
+
+    async def test_truncation_status(self):
+        """Test truncation status for multi-turn conversations."""
+        await prepare_engines(self.engines, self.auxiliary_engines)
+        await self.model_wrapper.prepare()
+
+        # Case: "prompt_truncated"
+        messages = [
+            {"role": "system", "content": "You are helpful."},
+            {"role": "user", "content": "A very long prompt." * 20},
+            {"role": "assistant", "content": "OK"},
+        ]
+        converted_experience = self.model_wrapper.convert_messages_to_experience(
+            messages,
+        )
+        self._check_experience(converted_experience, "prompt_truncated")
+
+        # Case: No truncation
+        messages = [
+            {"role": "system", "content": "You are helpful."},
+            {"role": "user", "content": "Tell me about weather."},
+            {"role": "assistant", "content": "OK"},
+        ]
+        converted_experience = self.model_wrapper.convert_messages_to_experience(
+            messages,
+        )
+        self._check_experience(converted_experience, None)
+
+    async def test_no_prompt_truncation(self):
+        """Test truncation status when prompt truncation is disabled."""
+        self.config.model.enable_prompt_truncation = False
+        self.config.check_and_update()
+        await prepare_engines(self.engines, self.auxiliary_engines)
+        await self.model_wrapper.prepare()
+
+        # Case: No truncation
+        messages = [
+            {"role": "system", "content": "You are helpful."},
+            {"role": "user", "content": "Tell me about weather."},
+        ]
+        converted_experience = self.model_wrapper.convert_messages_to_experience(messages)
+        self._check_experience(converted_experience, None)
+
+        # Case: "response_truncated"
+        messages = [
+            {"role": "system", "content": "You are helpful."},
+            {"role": "user", "content": "Tell me about weather."},
+            {"role": "assistant", "content": "A very long response" * 20},
+        ]
+        converted_experience = self.model_wrapper.convert_messages_to_experience(messages)
+        self._check_experience(converted_experience, "response_truncated")
+
+    def _check_experience(self, exp, target_truncate_status):
+        self.assertIsNotNone(exp)
+        model_len = len(exp.tokens)
+        prompt_length = exp.prompt_length
+        self.assertEqual(exp.truncate_status, target_truncate_status)
+        self.assertLessEqual(prompt_length, self.config.model.max_prompt_tokens)
+        self.assertLessEqual(model_len, self.config.model.max_model_len)
+
+
 class TestAPIServer(RayUnittestBaseAsync):
     def setUp(self):
         self.config = get_template_config()
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index ed982fe9b9..a139ec3a39 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -163,7 +163,13 @@ async def convert_messages_to_experience(
         tools: Optional[List[dict]] = None,
         temperature: Optional[float] = None,
     ) -> Experience:
-        """Convert a list of messages into an experience in async."""
+        """Convert a list of messages into an experience in async.
+
+        Args:
+            messages: List of message dictionaries
+            tools: Optional list of tools
+            temperature: Optional temperature for logprobs calculation
+        """
         if self.tokenizer is None:
             await self._initialize_tokenizer()
         if self.chat_template is None:
@@ -176,22 +182,43 @@ async def convert_messages_to_experience(
             enable_thinking=self.enable_thinking,
         )  # (seq_length, ), (seq_length, )
-        # Truncate tokens if they exceed the length limit
         assert token_ids is not None
         truncate_status = None
-        if self.config.max_model_len is not None and self.config.max_model_len > 0:
-            if len(token_ids) > self.config.max_model_len - 1:
-                truncate_status = "response_truncated"
-                self.logger.warning(
-                    f"Warning: {len(token_ids)=} exceeds the length limit {(self.config.max_model_len - 1)=}"
-                )
-                token_ids = token_ids[: self.config.max_model_len - 1]
-                action_mask = action_mask[: self.config.max_model_len - 1]
+        # Truncate prompt if it exceeds max_prompt_tokens
+        if (
+            self.config.enable_prompt_truncation
+            and self.config.max_prompt_tokens is not None
+            and prompt_length > self.config.max_prompt_tokens
+        ):
+            truncate_status = "prompt_truncated"
+            self.logger.warning(
+                f"Warning: {prompt_length=} exceeds the length limit {self.config.max_prompt_tokens}, "
+                f"this experience will not be counted in the loss computation."
+            )
+            return Experience(
+                tokens=token_ids[: self.config.max_prompt_tokens + 1],
+                logprobs=torch.zeros(1, dtype=torch.float32),
+                prompt_length=self.config.max_prompt_tokens,  # Use truncated length
+                action_mask=torch.zeros(1, dtype=torch.bool),  # ignored in loss computation
+                messages=messages,  # messages are not truncated
+                truncate_status=truncate_status,
+            )
+
+        # Truncate response if it exceeds max_model_len
+        max_model_len = self.config.max_model_len
+        if max_model_len is not None and len(token_ids) > max_model_len - 1:
+            truncate_status = "response_truncated"
+            self.logger.warning(
+                f"Warning: {len(token_ids)=} exceeds the length limit {(max_model_len - 1)=}"
+            )
+            token_ids = token_ids[: max_model_len - 1]
+            action_mask = action_mask[: max_model_len - 1]
         temperature = temperature if temperature is not None else self.config.temperature
         logprobs = await self.logprobs(
             token_ids=token_ids.tolist(), temperature=temperature
         )  # (seq_length - 1,)
+
         return Experience(
             tokens=token_ids,
             logprobs=logprobs[prompt_length - 1 :],
diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py
index 0c45466a70..eeb479cbc3 100644
--- a/trinity/common/workflows/workflow.py
+++ b/trinity/common/workflows/workflow.py
@@ -223,7 +223,10 @@ def process_messages_to_experience(
     ) -> Experience:
         converted_experience = self.model.convert_messages_to_experience(messages)
         return self._build_experience_from_converted(
-            converted_experience, reward, info, truncate_status
+            converted_experience,
+            reward,
+            info,
+            converted_experience.truncate_status or truncate_status,
         )
 
     async def process_messages_to_experience_async(
@@ -231,7 +234,10 @@ async def process_messages_to_experience_async(
     ) -> Experience:
         converted_experience = await self.model.convert_messages_to_experience_async(messages)
         return self._build_experience_from_converted(
-            converted_experience, reward, info, truncate_status
+            converted_experience,
+            reward,
+            info,
+            converted_experience.truncate_status or truncate_status,
         )