diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index bcedf2e86e..a1d47a48ff 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -199,6 +199,7 @@ async def test_generate(
         self.assertTrue(
             torch.equal(result_dict["assistant_masks"][0][prompt_length:], exp.action_mask)
         )
+        self.assertTrue(exp.logprobs.shape[0] == exp.tokens.shape[0] - prompt_length)
         self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))
         self.assertRaises(ValueError, self.model_wrapper.get_openai_client)
         if self.config.explorer.rollout_model.enable_history:
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 7cedf86687..5d6ff40bd2 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -172,7 +172,8 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
         return experiences

     async def logprobs(self, token_ids: List[int]) -> torch.Tensor:
-        """Calculate the logprobs of the given tokens in async.
+        """Calculate the logprobs of the given tokens in async. Please slice the result carefully
+            to align with the actual response length.

         Args:
             token_ids (List[int]): The input token ids (seq_length).
@@ -217,11 +218,11 @@ async def convert_messages_to_experience(self, messages: List[dict]) -> Experien
             self.chat_template = self.tokenizer.get_chat_template()
         token_ids, action_mask, prompt_length = self.action_mask_method(
             self.tokenizer, messages, self.chat_template
-        )
-        logprobs = await self.logprobs(token_ids=token_ids.tolist())
+        )  # (seq_length, ), (seq_length, )
+        logprobs = await self.logprobs(token_ids=token_ids.tolist())  # (seq_length - 1,)
         return Experience(
             tokens=token_ids,
-            logprobs=logprobs,
+            logprobs=logprobs[prompt_length - 1 :],
             prompt_length=prompt_length,
             action_mask=action_mask[prompt_length:],  # Exclude the prompt tokens
         )