
Refactor RewardTrainer hyperparameters into dedicated dataclass #726

Merged: 10 commits from wrap-rm-args into main on Sep 5, 2023

Conversation

lewtun (Member) commented Sep 1, 2023

This PR migrates the max_length arg of the RewardTrainer into a dedicated RewardTrainingArguments class that can also be used to store future hyperparameters.

To remain backwards compatible, I've kept the argument in the trainer's init, with a warning that it will be removed in a future version.
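
In code terms, the shim is roughly the following (a minimal sketch rather than the exact diff; the signature is trimmed for illustration and only the deprecation path is shown):

import warnings

from transformers import Trainer


class RewardTrainer(Trainer):
    def __init__(self, *trainer_args, args=None, max_length=None, **kwargs):
        if max_length is not None:
            # Still accept the old argument, but point users at the new dataclass.
            warnings.warn(
                "The `max_length` argument is deprecated and will be removed in a future version."
                " Please use the `RewardTrainingArguments` to set `max_length` instead."
            )
        super().__init__(*trainer_args, args=args, **kwargs)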

Tested with:

accelerate launch --multi_gpu --num_processes 2 examples/scripts/reward_trainer.py --batch_size 1

@@ -94,18 +95,21 @@ def __init__(
The optimizer and scheduler to use for training.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
The function to use to preprocess the logits before computing the metrics.
max_length (`int`, defaults to `None`):
lewtun (Member Author):
I followed the transformers convention of removing deprecated args from public docstrings.

peft_config (`Dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
"""
if max_length is not None:
warnings.warn(
"The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardTrainingArguments` to set `max_length` instead.",
lewtun (Member Author):
Should we specify a precise version, e.g. 0.9.0?

Member:
No, let's just keep the warning as it is :)

HuggingFaceDocBuilderDev commented Sep 1, 2023

The documentation is not available anymore as the PR was closed or merged.



@dataclass
class RewardTrainingArguments(TrainingArguments):
lewtun (Member Author):
Another name for this could be RewardConfig to be more aligned with PPOConfig

Member:
No strong opinion, whatever you prefer.

lewtun (Member Author):
I opted for RewardConfig in e6e346d

I've also placed this in a training_configs.py module; let me know if you'd prefer a dedicated module per config (currently PPOConfig lives in ppo_config.py).
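
For readers skimming the thread, here is a rough sketch of what such a config class looks like (the docstring and the help text are illustrative, not the exact contents of e6e346d):

from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments


@dataclass
class RewardConfig(TrainingArguments):
    """Training arguments for the RewardTrainer, extending transformers.TrainingArguments."""

    max_length: Optional[int] = field(
        default=None,
        metadata={"help": "Maximum length of the sequences in the batch; used by the data collator."},
    )

A caller would then build a RewardConfig(output_dir=..., max_length=512) and pass it to the trainer via args, instead of passing max_length to RewardTrainer directly.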

lewtun (Member Author) commented Sep 4, 2023

I think this is good to go, so gently pinging @lvwerra @younesbelkada for a final review 🙏

younesbelkada (Contributor) left a comment:

Thanks a lot @lewtun! This looks great to me. I just have one suggestion about the name of the class: what about RewardTrainingArguments? I am also happy with the current naming if @lvwerra agrees.

Comment on lines 125 to 142
                    "When using RewardDataCollatorWithPadding, you should set `max_length` in the RewardTrainer's init"
                    " it will be set to `512` by default, but you should do it yourself in the future.",
                    "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
                    " It will be set to `512` by default, but you should do it yourself in the future.",
                    UserWarning,
                )
                max_length = 512
            elif args.max_length is None:
                warnings.warn(
                    "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
                    " It will be set to `512` by default, but you should do it yourself in the future.",
                    UserWarning,
                )
                max_length = 512
            else:
                max_length = args.max_length
Member:
I think in the case where max_length is not None then it is still overwritten when args.max_length is None, right? I think in that case we should keep the original value to be backwards compatible.

lewtun (Member Author):
Oh yes, good catch - I'll fix that

lewtun (Member Author):
Fixed in 5b67e0d
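
To make the intended behaviour concrete, here is a hypothetical helper (resolve_max_length is not part of the PR) capturing the precedence discussed above: an explicitly passed, deprecated max_length is kept for backwards compatibility, the config value is used otherwise, and 512 is the warned-about fallback.

import warnings
from typing import Optional


def resolve_max_length(deprecated_max_length: Optional[int], config_max_length: Optional[int]) -> int:
    # A value passed directly to the trainer (the deprecated path) wins, so
    # existing callers keep their behaviour.
    if deprecated_max_length is not None:
        return deprecated_max_length
    # Otherwise fall back to the value set in RewardConfig, if any.
    if config_max_length is not None:
        return config_max_length
    # Nothing was set anywhere: warn and use the historical default of 512.
    warnings.warn(
        "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
        " It will be set to `512` by default, but you should do it yourself in the future.",
        UserWarning,
    )
    return 512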

lewtun (Member Author) commented Sep 4, 2023

> Thanks a lot @lewtun! This looks great to me. I just have one suggestion about the name of the class: what about RewardTrainingArguments? I am also happy with the current naming if @lvwerra agrees.

Are you suggesting we use RewardTrainingArguments instead of RewardConfig? I originally had the former, but switched to the latter to be aligned with PPOConfig. I don't have a strong opinion, but whatever choice we make should stay consistent for all the other trainers, e.g. SFTTrainingArguments vs SFTConfig etc.

One advantage of XConfig classes is that it's fewer keystrokes :D

lvwerra (Member) left a comment:

LGTM! 🚀

younesbelkada (Contributor) left a comment:

Thanks @lewtun, makes sense, let's merge it! 🚀

lewtun merged commit d484dc2 into main on Sep 5, 2023
lewtun deleted the wrap-rm-args branch on September 5, 2023 at 07:05
kushal-tri pushed a commit to kushalarora/trl that referenced this pull request on Sep 19, 2023:

Refactor RewardTrainer hyperparameters into dedicated dataclass (huggingface#726)

* Refactor RewardTrainer hyperparameters into dedicated dataclass
* Revert
* Add doc string
* Fix warning
* Handle backwards compat
* Fix tests
* Add docs
* Refactor to RewardConfig
* Fix case conditions
* Fix
lapp0 pushed a commit to lapp0/trl that referenced this pull request on May 10, 2024:

Refactor RewardTrainer hyperparameters into dedicated dataclass (huggingface#726)
