diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index de809de2ce..553d8acb4d 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -591,7 +591,7 @@ def batched_forward_pass( if len(logprobs[j, start:end]) < 2: raise ValueError("Responses are too short. Make sure they are at least 4 tokens long.") - all_values.append(v[j, start - 1 : end - 1]) + all_values.append(v[j, start:end]) all_logprobs.append(logprobs[j, start:end]) all_ref_logprobs.append(ref_logprobs[j, start:end]) @@ -713,11 +713,11 @@ def loss( if self.is_encoder_decoder: logprob = logprobs_from_logits(logits[:, :-1, :], model_input[:, 1:]) start, end = 1, response.shape[-1] - 1 - vpred = vpred[:, start - 1 : end - 1] + vpred = vpred[:, start:end] logprob = logprob[:, start:end] else: logprob = logprobs_from_logits(logits[:, :-1, :], model_input[:, 1:]) - logprob, vpred = logprob[:, -gen_len:], vpred[:, -gen_len - 1 : -1] + logprob, vpred = logprob[:, -gen_len:], vpred[:, -gen_len:] vpredclipped = clip_by_value(vpred, values - self.config.cliprange_value, values + self.config.cliprange_value)