Fix logits/chosen and logits/rejected metrics in kto_trainer. #2077

Merged 3 commits on Sep 18, 2024


PhilipMay
Contributor

The calculation of the logits/chosen and logits/rejected metrics in kto_trainer seems to be wrong. Applying nansum() followed by nanmean() to policy_rejected_logits is incorrect: the nanmean() averages the per-process sums over the number of processes, not over the number of examples.

Our fix is to apply nansum() followed by another nansum() and then divide the result by count/chosen or count/rejected.
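
The difference between the two reductions can be sketched in plain Python. This is a toy illustration, not the trainer code: `nansum` and `nanmean` are hypothetical stand-ins for the tensor methods, the "gather" is simulated with a list, and `all_num_chosen` is assumed here to be the total number of non-NaN entries across processes.

```python
import math

def nansum(values):
    """Sum the values, skipping NaN entries (stand-in for a tensor nansum)."""
    return sum(v for v in values if not math.isnan(v))

def nanmean(values):
    """Average the non-NaN values; returns NaN if there are none."""
    kept = [v for v in values if not math.isnan(v)]
    return sum(kept) / len(kept) if kept else math.nan

# Hypothetical per-process logits; sizes differ and one entry is NaN.
per_process_logits = [[1.0, 2.0, math.nan, 3.0], [10.0]]

# Each process reduces locally; the scalars are then "gathered" into one list.
gathered_sums = [nansum(p) for p in per_process_logits]
# Assumption: the count is the number of non-NaN entries over all processes.
all_num_chosen = sum(
    1 for p in per_process_logits for v in p if not math.isnan(v)
)

old_metric = nanmean(gathered_sums)                   # mean of the local sums
new_metric = nansum(gathered_sums) / all_num_chosen   # total sum / total count

print(old_metric, new_metric)
```

Only `new_metric` matches the mean taken over all non-NaN values at once; `old_metric` depends on how the values happen to be split across processes.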

@PhilipMay
Contributor Author

PhilipMay commented Sep 18, 2024

Tagging @MAOJIASONG as the original author and @claralp and asking for a review. 🙏🏼

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec
Member

qgallouedec commented Sep 18, 2024

Now I understand:

Let's take the following example:

              process 1   process 2
logits        [0, 1, 2]   [3, 4]
sum               3         7
gather              [3, 7]
mean                  5

Your proposed fix

              process 1   process 2
logits        [0, 1, 2]   [3, 4]
sum               3         7
gather              [3, 7]
sum                   10
/all_num_chosen    10 / 5 = 2
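
Written as formulas, the two tables above compare the following quantities. The symbols are introduced here only for illustration: $P$ is the number of processes, $S_p$ the local sum on process $p$, and $n_p$ the number of logits on process $p$.

```latex
\[
\underbrace{\frac{1}{P}\sum_{p=1}^{P} S_p}_{\text{sum, gather, mean}}
  = \frac{3 + 7}{2} = 5
\qquad\text{vs.}\qquad
\underbrace{\frac{\sum_{p=1}^{P} S_p}{\sum_{p=1}^{P} n_p}}_{\text{sum, gather, sum, divide}}
  = \frac{3 + 7}{3 + 2} = 2
\]
```

The two agree only when the denominator $P$ happens to equal the total count $\sum_p n_p$, that is, when every process holds exactly one logit.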

Would the following work?

              process 1   process 2
logits        [0, 1, 2]   [3, 4]
gather          [0, 1, 2, 3, 4]
sum                   10
/all_num_chosen    10 / 5 = 2

Probably, but too memory intensive, right?
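
A rough way to see the memory concern, using the numbers from the tables above: both variants give the same mean, but gathering raw logits moves every value, while gathering local sums moves one scalar per process. This is a plain-Python simulation with made-up variable names, not a distributed call; a real tensor gather typically also needs equally shaped inputs on every process, which the raw-logits variant would not have here.

```python
# Per-process logits from the example tables.
per_process_logits = [[0, 1, 2], [3, 4]]
all_num_chosen = sum(len(p) for p in per_process_logits)

# Variant A: gather every logit, then reduce.
gathered_all = [v for p in per_process_logits for v in p]
mean_from_all = sum(gathered_all) / all_num_chosen
values_moved_all = len(gathered_all)    # grows with the number of logits

# Variant B: reduce locally, gather one sum per process, then reduce again.
gathered_sums = [sum(p) for p in per_process_logits]
mean_from_sums = sum(gathered_sums) / all_num_chosen
values_moved_sums = len(gathered_sums)  # grows only with the number of processes

print(mean_from_all, values_moved_all)
print(mean_from_sums, values_moved_sums)
```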

@qgallouedec qgallouedec merged commit 0d2bee5 into huggingface:main Sep 18, 2024
9 checks passed
@MAOJIASONG
Contributor

May I ask what the difference is between the first one and the proposed second one? Why is it better to use all_num_chosen for division? @qgallouedec

@qgallouedec
Member

Because otherwise the result is wrong (5)

@PhilipMay
Contributor Author

> Probably, but too memory intensive, right?

@qgallouedec I am not sure about this to be honest.

@MAOJIASONG
Contributor

MAOJIASONG commented Sep 19, 2024

> Because otherwise the result is wrong (5)

Thx, I misinterpreted the example. Thanks for pointing it out.

@PhilipMay PhilipMay changed the title from "[WIP] Fix logits/chosen and logits/rejected metrics in kto_trainer." to "Fix logits/chosen and logits/rejected metrics in kto_trainer." on Sep 21, 2024
5 participants