diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py
index 96c89b7855..92aac8509f 100644
--- a/trl/trainer/ppo_config.py
+++ b/trl/trainer/ppo_config.py
@@ -115,6 +115,8 @@ class PPOConfig:
     """Use score normalization. Only applicable if use_score_scaling is True"""
     score_clip: Optional[float] = None
     """Score clipping"""
+    whiten_rewards: bool = False
+    """Whiten the rewards before compute advantages"""
 
     # computed hyperparameters at runtime; we use `tyro.conf.Suppress` to hide them from the help text
     is_encoder_decoder: Optional[tyro.conf.Suppress[bool]] = None
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 42a921b3ac..47b298aa37 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -1095,7 +1095,7 @@ def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor
         raise NotImplementedError
 
     def compute_advantages(
-        self: torch.FloatTensor,
+        self,
         values: torch.FloatTensor,
         rewards: torch.FloatTensor,
         mask: torch.FloatTensor,
@@ -1107,6 +1107,9 @@ def compute_advantages(
         values = values * mask
         rewards = rewards * mask
 
+        if self.config.whiten_rewards:
+            rewards = masked_whiten(rewards, mask, shift_mean=False)
+
         for t in reversed(range(gen_len)):
             nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
             delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
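
Note: `masked_whiten` is the masked-statistics helper from `trl.core`. Called with `shift_mean=False`, the rewards are rescaled to unit variance over the positions selected by `mask` while their mean is left in place, so the sign and rough scale of the reward signal are preserved before the GAE loop runs. The sketch below illustrates that intended behavior under these assumptions; it is an illustrative re-implementation, not the library source.

import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Mean computed only over positions where mask == 1.
    return (values * mask).sum() / mask.sum()


def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    # Normalize the masked entries to unit variance; when shift_mean=False,
    # add the mean back so rewards are only rescaled, not re-centered.
    mean = masked_mean(values, mask)
    var = masked_mean((values - mean) ** 2, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened


# Example: the new option is opt-in and would be enabled via the config, e.g.
# config = PPOConfig(whiten_rewards=True)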