Add whitening op before computing advantages (#887)
* Add whitening op before computing advantages

1. The LLaMA 2 paper says:
```
We also find it important to whiten the final linear scores (shown here by reversing the sigmoid with the logit function) in order to increase stability and balance properly with the KL penalty term (β) above.
```
2. The whitening function is taken from [alpaca_farm](https://github.com/tatsu-lab/alpaca_farm/blob/64e489c67ea502ab5fa944bebde3078c9722f6ee/src/alpaca_farm/rl/ppo_trainer.py#L86); a sketch of such a helper is included after the commit message below.

* Fix the type annotation of `self`

---------

Co-authored-by: Lin Junpeng <linjunpeng@sensetime.com>
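
For context, here is a minimal sketch of what such a masked whitening helper might look like. It mirrors the `masked_whiten` call used in the diff, but the epsilon and the variance estimator below are assumptions; the actual TRL and alpaca_farm implementations may differ in detail.

```python
import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean over the entries where mask == 1."""
    return (values * mask).sum() / mask.sum()


def masked_var(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """(Biased) variance over the entries where mask == 1."""
    mean = masked_mean(values, mask)
    return masked_mean((values - mean) ** 2, mask)


def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    """Normalize `values` to unit variance over the masked entries.

    With shift_mean=False (the setting used for rewards in this commit) the
    original mean is added back, so only the scale is normalized.
    """
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened = whitened + mean
    return whitened
```
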
SingL3 and Lin Junpeng authored Oct 23, 2023
1 parent 304ee70 commit 1f3314f
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 2 additions & 0 deletions trl/trainer/ppo_config.py
```diff
@@ -115,6 +115,8 @@ class PPOConfig:
     """Use score normalization. Only applicable if use_score_scaling is True"""
     score_clip: Optional[float] = None
     """Score clipping"""
+    whiten_rewards: bool = False
+    """Whiten the rewards before computing advantages"""
 
     # computed hyperparameters at runtime; we use `tyro.conf.Suppress` to hide them from the help text
     is_encoder_decoder: Optional[tyro.conf.Suppress[bool]] = None
```
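
With the new field in place, a user opts into reward whitening through the config. A quick usage sketch follows; the other arguments are illustrative and not part of this commit.

```python
from trl import PPOConfig

# Illustrative values; only whiten_rewards is introduced by this commit.
config = PPOConfig(
    model_name="gpt2",
    learning_rate=1.41e-5,
    whiten_rewards=True,  # scale-normalize per-token rewards before computing advantages
)
```
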
5 changes: 4 additions & 1 deletion trl/trainer/ppo_trainer.py
```diff
@@ -1095,7 +1095,7 @@ def _kl_penalty(self, logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor
         raise NotImplementedError
 
     def compute_advantages(
-        self: torch.FloatTensor,
+        self,
         values: torch.FloatTensor,
         rewards: torch.FloatTensor,
         mask: torch.FloatTensor,
@@ -1107,6 +1107,9 @@ def compute_advantages(
         values = values * mask
         rewards = rewards * mask
 
+        if self.config.whiten_rewards:
+            rewards = masked_whiten(rewards, mask, shift_mean=False)
+
         for t in reversed(range(gen_len)):
             nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
             delta = rewards[:, t] + self.config.gamma * nextvalues - values[:, t]
```
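
To show where the new branch sits in the overall computation, here is a simplified, self-contained sketch of a GAE-style advantage loop with the optional reward whitening. It reuses the `masked_whiten` sketch above, `gamma` and `lam` stand in for the corresponding config fields, and trainer-specific details are omitted.

```python
import torch


def compute_advantages_sketch(
    values: torch.Tensor,    # (batch, gen_len) value estimates
    rewards: torch.Tensor,   # (batch, gen_len) per-token rewards
    mask: torch.Tensor,      # (batch, gen_len) 1 for response tokens, 0 for padding
    gamma: float = 1.0,
    lam: float = 0.95,
    whiten_rewards: bool = False,
):
    values = values * mask
    rewards = rewards * mask

    # New in this commit: normalize the reward scale (the mean is preserved
    # because shift_mean=False) before the rewards enter the GAE recursion.
    if whiten_rewards:
        rewards = masked_whiten(rewards, mask, shift_mean=False)

    gen_len = rewards.shape[-1]
    lastgaelam = 0.0
    advantages_reversed = []
    for t in reversed(range(gen_len)):
        nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgaelam = delta + gamma * lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)

    returns = advantages + values
    return advantages, returns
```
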
