Refactor RewardTrainer hyperparameters into dedicated dataclass #726
Conversation
@@ -94,18 +95,21 @@ def __init__(
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
        max_length (`int`, defaults to `None`):
I followed the `transformers` convention of removing deprecated args from public docstrings.
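For readers unfamiliar with that convention: the deprecated parameter stays in the signature so existing callers keep working, but it is dropped from the public docstring. A minimal sketch of the pattern (class and argument names here are illustrative, not the exact TRL code):

```python
class ExampleTrainer:
    def __init__(self, model=None, args=None, max_length=None):
        """
        Args:
            model: The model to train.
            args: The training arguments (the new home for `max_length`).
        """
        # `max_length` is still accepted for backwards compatibility, but it is
        # deliberately absent from the docstring above because it is deprecated.
        self.max_length = max_length if max_length is not None else getattr(args, "max_length", None)
```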
trl/trainer/reward_trainer.py
Outdated
        peft_config (`Dict`, defaults to `None`):
            The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
        """
        if max_length is not None:
            warnings.warn(
                "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardTrainingArguments` to set `max_length` instead.",
Should we specify a precise version, e.g. 0.9.0?
No, let's just keep the warning as it is :)
The documentation is not available anymore as the PR was closed or merged.
trl/trainer/training_args.py
Outdated
@dataclass
class RewardTrainingArguments(TrainingArguments):
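The hunk only shows the class declaration; a config subclass along these lines is what is being added. The sketch below is based on the surrounding discussion - only the `max_length` field is confirmed by this PR, the docstring and metadata are assumptions:

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments


@dataclass
class RewardTrainingArguments(TrainingArguments):
    """Training arguments for the RewardTrainer, extending `transformers.TrainingArguments`."""

    max_length: Optional[int] = field(
        default=None,
        metadata={"help": "Maximum length of the sequences in the batch."},
    )
```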
Another name for this could be `RewardConfig`, to be more aligned with `PPOConfig`.
No strong opinion, whatever you prefer.
I opted for `RewardConfig` in e6e346d. I've also placed this in a `training_configs.py` module - let me know if you'd prefer a dedicated module per config (currently `PPOConfig` lives in `ppo_config.py`).
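To make the end state concrete, here is a hedged usage sketch of the renamed config passed to the trainer; the model, tokenizer, and dataset are placeholders, and the argument names reflect the TRL API around the time of this PR rather than the merged diff itself:

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer

model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# `max_length` now lives on the config instead of the trainer's init.
config = RewardConfig(output_dir="reward_model", max_length=512)

trainer = RewardTrainer(
    model=model,
    args=config,
    tokenizer=tokenizer,
    train_dataset=train_dataset,  # placeholder: a dataset of chosen/rejected pairs
)
trainer.train()
```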
I think this is good to go, so gently pinging @lvwerra @younesbelkada for a final review 🙏
trl/trainer/reward_trainer.py
Outdated
"When using RewardDataCollatorWithPadding, you should set `max_length` in the RewardTrainer's init" | ||
" it will be set to `512` by default, but you should do it yourself in the future.", | ||
"When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig." | ||
" It will be set to `512` by default, but you should do it yourself in the future.", | ||
UserWarning, | ||
) | ||
max_length = 512 | ||
elif args.max_length is None: | ||
warnings.warn( | ||
"When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig." | ||
" It will be set to `512` by default, but you should do it yourself in the future.", | ||
UserWarning, | ||
) | ||
max_length = 512 | ||
else: | ||
max_length = args.max_length |
I think in the case where `max_length` is not `None`, it is still overwritten when `args.max_length` is `None`, right? I think in that case we should keep the original value to be backwards compatible.
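A small sketch of the backwards-compatible precedence being asked for here - illustrative only, not the code from the fix commit:

```python
import warnings


def resolve_max_length(deprecated_max_length, args):
    # Keep honouring the deprecated init argument so existing callers are not broken.
    if deprecated_max_length is not None:
        return deprecated_max_length
    # Otherwise defer to the new config, falling back to the historical default of 512.
    if args is not None and args.max_length is not None:
        return args.max_length
    warnings.warn(
        "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
        " It will be set to `512` by default, but you should do it yourself in the future.",
        UserWarning,
    )
    return 512
```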
Oh yes, good catch - I'll fix that
Fixed in 5b67e0d
Are you suggesting we use […]? One advantage of […]
LGTM! 🚀
Thanks @lewtun, makes sense, let's merge it! 🚀
…ingface#726) * Refactor RewardTrainer hyperparameters into dedicated dataclass * Revert * Add doc string * Fix warning * Handle backwards compat * Fix tests * Add docs * Refactor to RewardConfig * Fix case conditions * Fix
This PR migrates the `max_length` arg of the `RewardTrainer` into a dedicated `RewardTrainingArguments` class that can also be used for storing future hyperparameters. To be backwards compatible, I've left the variable in the trainer's init, with a warning that this will be removed in some future version.
Tested with: