Skip to content

Commit

Permalink
Fix case conditions
Browse files — browse the repository at this point in the history
  • Loading branch information
lewtun committed Sep 4, 2023
1 parent e6e346d commit 5b67e0d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
21 changes: 10 additions & 11 deletions trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class RewardTrainer(Trainer):
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
args: RewardConfig = None,
args: Optional[RewardConfig] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
Expand Down Expand Up @@ -98,7 +98,12 @@ def __init__(
peft_config (`Dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
"""
if max_length is not None:
print(args)
if max_length is not None and args.max_length is not None:
raise ValueError(
"You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once."
)
if max_length is not None and args.max_length is None:
warnings.warn(
"The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
FutureWarning,
Expand All @@ -124,22 +129,16 @@ def __init__(
raise ValueError(
"max_length or a tokenizer must be specified when using the default RewardDataCollatorWithPadding"
)
if max_length is None:
warnings.warn(
"When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
" It will be set to `512` by default, but you should do it yourself in the future.",
UserWarning,
)
max_length = 512
elif args.max_length is None:
if max_length is None and args.max_length is None:
warnings.warn(
"When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
" It will be set to `512` by default, but you should do it yourself in the future.",
UserWarning,
)
max_length = 512
else:
if max_length is None and args.max_length is not None:
max_length = args.max_length

data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length)

if args.remove_unused_columns:
Expand Down
3 changes: 1 addition & 2 deletions trl/trainer/training_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
@dataclass
class RewardConfig(TrainingArguments):
"""
RewardConfig is the subset of the arguments we use in our example scripts **which relate to the training loop
itself**.
RewardConfig collects all training arguments related to the [`RewardTrainer`] class.
Using [`HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
Expand Down

0 comments on commit 5b67e0d

Please sign in to comment.