fix: improve error message when pad_token_id is not configured (#1152)
* fix: improve error message when `pad_token_id` is not configured

* Add test for error raised when pad_token is None

* Fix pre-commit errors

* Fix error in the test environment
yumemio authored Jan 17, 2024
1 parent 97b9fa2 commit 341f6a6
Showing 2 changed files with 47 additions and 0 deletions.
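
In practice the new check fires when a tokenizer has no padding token configured, which is common for decoder-only base tokenizers such as GPT-2's. A minimal sketch of the remedy the new message points to — the `gpt2` checkpoint below is only an illustrative choice, not something taken from this commit:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # gpt2's tokenizer ships without a pad token
    print(tokenizer.pad_token)  # None

    # The fix suggested by the error message: reuse the EOS token for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(tokenizer.pad_token_id)  # now a valid id (50256 for gpt2) instead of None

With a pad token in place, the collator changed below can fill shorter sequences up to the batch length and the new `ValueError` is never raised.
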
35 changes: 35 additions & 0 deletions tests/test_dpo_trainer.py
@@ -228,6 +228,41 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self):
            if param.sum() != 0:
                self.assertFalse(torch.equal(param, new_param))

    def test_dpo_trainer_padding_token_is_none(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=3,
                remove_unused_columns=False,
                gradient_accumulation_steps=1,
                learning_rate=9e-1,
                evaluation_strategy="steps",
            )

            dummy_dataset = self._init_dummy_dataset()

            tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            tokenizer.pad_token = None

            with self.assertRaisesRegex(
                ValueError,
                expected_regex=r"Padding is enabled, but the tokenizer is not configured with a padding token."
                r" Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\)"
                r" before calling the trainer.",
            ):
                trainer = DPOTrainer(
                    model=self.model,
                    ref_model=None,
                    beta=0.1,
                    args=training_args,
                    tokenizer=tokenizer,
                    train_dataset=dummy_dataset,
                    eval_dataset=dummy_dataset,
                )

                trainer.train()

    @require_no_wandb
    def test_dpo_trainer_generate_during_eval_no_wandb(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
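
The backslash-escaped parentheses in the test are needed because `assertRaisesRegex` treats `expected_regex` as a regular expression that is searched against the text of the raised exception. A quick standalone check of that match, independent of the trainer — the `message` literal here simply mirrors the string added in `trl/trainer/utils.py` below:

    import re

    message = (
        "Padding is enabled, but the tokenizer is not configured with a padding token."
        " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
        " before calling the trainer."
    )
    pattern = (
        r"Padding is enabled, but the tokenizer is not configured with a padding token."
        r" Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\)"
        r" before calling the trainer."
    )
    assert re.search(pattern, message) is not None
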
12 changes: 12 additions & 0 deletions trl/trainer/utils.py
@@ -304,6 +304,12 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
                    to_pad = [torch.LongTensor(ex[k]) for ex in features]

                    if (k.startswith("prompt")) and (k.endswith("input_ids")):
                        if self.pad_token_id is None:
                            raise ValueError(
                                "Padding is enabled, but the tokenizer is not configured with a padding token."
                                " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
                                " before calling the trainer."
                            )
                        padding_value = self.pad_token_id
                    elif k.endswith("_attention_mask"):
                        padding_value = 0
@@ -319,6 +325,12 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
                    else:
                        to_pad = [torch.LongTensor(ex[k]) for ex in features]
                    if k.endswith("_input_ids"):
                        if self.pad_token_id is None:
                            raise ValueError(
                                "Padding is enabled, but the tokenizer is not configured with a padding token."
                                " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
                                " before calling the trainer."
                            )
                        padding_value = self.pad_token_id
                    elif k.endswith("_labels"):
                        padding_value = self.label_pad_token_id
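
For context on why the collator needs a concrete id: each feature in the batch is padded out to the longest sequence, and that padding step needs a numeric fill value, so a `None` pad token cannot simply be passed through. A standalone sketch of the idea — this is not the collator's exact code, just an assumed illustration of the mechanics with `torch.nn.utils.rnn.pad_sequence`:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    pad_token_id = 50256  # e.g. an EOS id reused as the pad id; must not be None
    to_pad = [torch.LongTensor([5, 6, 7]), torch.LongTensor([8, 9])]

    # Pads the shorter row with pad_token_id so both rows share one length.
    batch = pad_sequence(to_pad, batch_first=True, padding_value=pad_token_id)
    print(batch)  # tensor([[    5,     6,     7], [    8,     9, 50256]])

Raising the descriptive `ValueError` up front spares users from decoding a lower-level failure at this padding step.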
