
Fix reported KL in PPO trainer #1180

Merged: 2 commits into huggingface:main on Jan 9, 2024

Conversation

mgerstgrasser

The PPO trainer currently calculates the KL that is reported in the stats and to wandb separately, and always uses the estimated KL for this. That's fine when the estimated KL is also used for training (kl_penalty = 'kl'), but it leads to a mismatch when using kl_penalty = 'abs', kl_penalty = 'mse', or kl_penalty = 'full'. Additionally, this can trigger the negative-KL warning even in cases where the KL actually used in training is not, and cannot be, negative.
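For context, here is a minimal sketch of the penalty variants that the `kl_penalty` option selects between (illustrative only, not the exact TRL implementation): the estimator used for reporting can be negative for individual tokens, while 'abs', 'mse', and 'full' are non-negative by construction.

```python
import torch
import torch.nn.functional as F

def kl_penalty(logprob, ref_logprob, kind="kl"):
    """Per-token KL penalty sketch; `kind` mirrors the `kl_penalty` config option."""
    if kind == "kl":
        # k1 estimator: individual tokens can contribute negative values
        return logprob - ref_logprob
    if kind == "abs":
        # absolute difference: never negative
        return (logprob - ref_logprob).abs()
    if kind == "mse":
        # squared difference: never negative
        return 0.5 * (logprob - ref_logprob).square()
    if kind == "full":
        # full KL over the vocabulary; here the inputs would be full
        # log-prob distributions rather than gathered token log-probs
        return F.kl_div(ref_logprob, logprob, log_target=True, reduction="none").sum(-1)
    raise ValueError(f"Unknown kl_penalty: {kind}")

# Example: the estimator can go negative where 'abs' cannot.
lp = torch.tensor([-1.0, -2.0])
ref = torch.tensor([-0.5, -3.0])
print(kl_penalty(lp, ref, "kl"))   # tensor([-0.5000,  1.0000])
print(kl_penalty(lp, ref, "abs"))  # tensor([0.5000, 1.0000])
```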

This PR changes it so that instead the KL that is calculated during training is kept and used for stats and reporting.
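Roughly, the shape of the fix (a hypothetical sketch reusing the `kl_penalty` helper from the sketch above; names and signatures are illustrative and differ from TRL's actual `compute_rewards()`): keep the per-token KL that the reward computation applies and report its masked mean, rather than recomputing the estimator separately.

```python
import torch

def compute_rewards_sketch(scores, logprobs, ref_logprobs, masks, kl_coef=0.1, kind="abs"):
    # KL actually used for the reward penalty (same variant as training).
    kls = kl_penalty(logprobs, ref_logprobs, kind)
    non_score_reward = -kl_coef * kls
    rewards = non_score_reward.clone()
    # Add the scalar score at the last response token
    # (assumes a left-aligned response mask for simplicity).
    last_idx = (masks.sum(dim=1) - 1).long()
    rewards[torch.arange(rewards.size(0)), last_idx] += scores
    # New: also return the KL tensor so stats report the same quantity.
    return rewards, non_score_reward, kls

def masked_mean(values, mask):
    return (values * mask).sum() / mask.sum()

# When recording stats, the reported KL now matches the training penalty:
# mean_kl = masked_mean(kls, masks)
```

With kl_penalty = 'abs' (or 'mse', 'full') the reported value is then non-negative by construction, so the negative-KL warning can no longer fire spuriously.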

Tested using the PPO example script in both single-GPU and Deepspeed stage 2 training, leading to identical results in the kl_penalty = 'kl' case. Also verified that kl_penalty = 'abs' no longer triggers the negative KL warning.

Closes #1161.

Previously this always reported the estimated KL, even when using `kl_penalty = 'full'` (or `'abs'`, etc.).
Now we return the actual KL calculated in `compute_rewards()` and report that.
@younesbelkada

cc @vwxyzjn @lvwerra, would you be able to have a look here whenever possible? 🙏

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vwxyzjn

vwxyzjn commented Jan 8, 2024

I think this is good. It's better to be more consistent. Any objections @lvwerra?

@younesbelkada

Thanks @vwxyzjn @lvwerra! Will merge then!

@younesbelkada younesbelkada merged commit a236c57 into huggingface:main Jan 9, 2024
9 checks passed
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
* Fix reported KL in PPO trainer

* fix test
Successfully merging this pull request may close these issues.

Negative KL warning even if KL='abs' and KL='full'