-
Notifications
You must be signed in to change notification settings - Fork 1.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add score scaling/normalization/clipping #560
Changes from 3 commits
2348147
2c573ca
f40fec0
6a79ca8
6ea53de
665aaaf
ee8213a
414f2f8
d391451
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,7 +45,6 @@ class ScriptArguments: | |
default=1, metadata={"help": "the number of gradient accumulation steps"} | ||
) | ||
early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"}) | ||
target_kl: Optional[float] = field(default=6, metadata={"help": "kl target for early stopping"}) | ||
use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use peft"}) | ||
use_seq2seq: Optional[bool] = field(default=False, metadata={"help": "whether to use seq2seq models"}) | ||
kl_penalty: Optional[str] = field( | ||
|
@@ -56,6 +55,9 @@ class ScriptArguments: | |
) | ||
target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"}) | ||
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"}) | ||
use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"}) | ||
use_score_norm: Optional[bool] = field(default=False, metadata={"help": "Use score normalization"}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we should clarify that this only works if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"}) | ||
|
||
|
||
parser = HfArgumentParser(ScriptArguments) | ||
|
@@ -72,6 +74,9 @@ class ScriptArguments: | |
target_kl=script_args.target_kl, | ||
kl_penalty=script_args.kl_penalty, | ||
seed=script_args.seed, | ||
use_score_scaling=script_args.use_score_scaling, | ||
use_score_norm=script_args.use_score_norm, | ||
score_clip=script_args.score_clip, | ||
) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This field seems to have been removed by mistake?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Younes,
You will find that
target_kl
already exists on L57 with a much smaller value.I dug deeper and found that
PPOConfig
has two configstarget
andtarget_kl
, wheretarget
has a default value of 6. So I assume the first duplicatetarget_kl
config here was meant to betarget
. However,target
is NOT used to populate PPOConfig at L64, so I just removed it.Regards,
Felix
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
great point, thank you !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is actually a bug from here: 1620da3
we overloaded the
target_kl
term - we should rename it!There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @edbeeching
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@lvwerra as much as I love introducing bugs into trl. I think this time it was @younesbelkada , in the Big refactor of examples and documentation (#509). Here
I agree to rename to
early_stop_kl
, or something