From d8d116d01b24b405a65e052198ad389bdaeb2517 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Tue, 9 Jan 2024 10:14:24 +0100
Subject: [PATCH] Revert "Address issue #1122 (#1174)"

This reverts commit d57d0f9ca46a63d370b91791352edda0154576f5.
---
 trl/trainer/sft_trainer.py | 6 +-----
 trl/trainer/utils.py       | 1 -
 2 files changed, 1 insertion(+), 6 deletions(-)

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 8eed33b241..9c06b102ff 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -418,11 +418,7 @@ def tokenize(element):
                 else:
                     self._dataset_sanity_checked = True
 
-            return {
-                "input_ids": outputs["input_ids"],
-                "labels": outputs["input_ids"],
-                "attention_mask": outputs["attention_mask"],
-            }
+            return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}
 
         tokenized_dataset = dataset.map(
             tokenize,
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 172b607b57..5c646c3876 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -459,7 +459,6 @@ def __iter__(self):
                     yield {
                         "input_ids": torch.LongTensor(example),
                         "labels": torch.LongTensor(example),
-                        "attention_mask": torch.ones(len(example)),
                     }