Fix type checking (huggingface#748)
lewtun authored and Andrew Lapp committed May 10, 2024
1 parent a4cd088 commit 438e1f4
Showing 1 changed file with 36 additions and 16 deletions.
trl/trainer/reward_trainer.py: 36 additions & 16 deletions
@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 from datasets import Dataset
-from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer
+from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_pt_utils import nested_detach
 from transformers.trainer_utils import EvalPrediction
@@ -98,15 +98,26 @@ def __init__(
             peft_config (`Dict`, defaults to `None`):
                 The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
         """
-        if max_length is not None and args.max_length is not None:
-            raise ValueError(
-                "You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once."
-            )
-        if max_length is not None and args.max_length is None:
+        if type(args) == TrainingArguments:
             warnings.warn(
-                "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
+                "Using `transformers.TrainingArguments` for `args` is deprecated and will be removed in a future version. Please use `RewardConfig` instead.",
                 FutureWarning,
             )
+            if max_length is not None:
+                warnings.warn(
+                    "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
+                    FutureWarning,
+                )
+        else:
+            if max_length is not None and args.max_length is not None:
+                raise ValueError(
+                    "You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once."
+                )
+            if max_length is not None and args.max_length is None:
+                warnings.warn(
+                    "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
+                    FutureWarning,
+                )
         if not is_peft_available() and peft_config is not None:
             raise ValueError(
                 "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
@@ -128,15 +139,24 @@ def __init__(
                 raise ValueError(
                     "max_length or a tokenizer must be specified when using the default RewardDataCollatorWithPadding"
                 )
-            if max_length is None and args.max_length is None:
-                warnings.warn(
-                    "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
-                    " It will be set to `512` by default, but you should do it yourself in the future.",
-                    UserWarning,
-                )
-                max_length = 512
-            if max_length is None and args.max_length is not None:
-                max_length = args.max_length
+            if type(args) == TrainingArguments:
+                if max_length is None:
+                    warnings.warn(
+                        "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
+                        " It will be set to `512` by default, but you should do it yourself in the future.",
+                        UserWarning,
+                    )
+                    max_length = 512
+            else:
+                if max_length is None and args.max_length is None:
+                    warnings.warn(
+                        "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
+                        " It will be set to `512` by default, but you should do it yourself in the future.",
+                        UserWarning,
+                    )
+                    max_length = 512
+                if max_length is None and args.max_length is not None:
+                    max_length = args.max_length
 
             data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length)
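Taken together, the two branches keep the old calling convention working while steering users toward `RewardConfig`. A minimal sketch of the deprecated and recommended call styles, assuming the usual Trainer-style keyword arguments; `model`, `tokenizer`, and `train_dataset` are hypothetical placeholders:

    from transformers import TrainingArguments
    from trl import RewardConfig, RewardTrainer

    # deprecated style: plain TrainingArguments plus the standalone max_length kwarg;
    # after this commit it still runs but emits FutureWarnings
    trainer = RewardTrainer(
        model=model,                  # hypothetical reward model
        args=TrainingArguments(output_dir="reward_model"),
        tokenizer=tokenizer,          # hypothetical tokenizer
        train_dataset=train_dataset,  # hypothetical preference dataset
        max_length=512,
    )

    # recommended style: RewardConfig carries max_length itself
    trainer = RewardTrainer(
        model=model,
        args=RewardConfig(output_dir="reward_model", max_length=512),
        tokenizer=tokenizer,
        train_dataset=train_dataset,
    )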
