[SFTTrainer] Fix non packed dataset #444

Merged · 3 commits · Jun 16, 2023
8 changes: 6 additions & 2 deletions docs/source/sft_trainer.mdx
@@ -65,8 +65,11 @@ Let us assume your dataset has two fields, `question` and `answer`. Therefore yo
 ```python
 ...
 def formatting_prompts_func(example):
-    text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
-    return text
+    output_texts = []
+    for i in range(len(example['question'])):
+        text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
+        output_texts.append(text)
+    return output_texts
Comment on lines +68 to +72
Contributor:
Oh interesting. So previously, we were dumping an entire dataset to the prompt?

Contributor Author:
Exactly, yes :D The previous examples in the documentation were wrong, and we were dumping entire mini-batches into the prompt when processing the dataset... :/
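For illustration (not part of the PR): `SFTTrainer` processes the dataset with `dataset.map(..., batched=True)` (see the diff to `trl/trainer/sft_trainer.py` below), so the formatting function receives a dict of column lists rather than a single example. A minimal sketch with made-up data shows why the old f-string dumped whole batches:

```python
# Minimal repro of the bug discussed above, with made-up data.
batch = {
    "question": ["What is 2+2?", "What is the capital of France?"],
    "answer": ["4", "Paris"],
}

# Old (buggy) formatting: the f-string stringifies the entire column lists.
buggy = f"### Question: {batch['question']}\n ### Answer: {batch['answer']}"
print(buggy)
# ### Question: ['What is 2+2?', 'What is the capital of France?']
#  ### Answer: ['4', 'Paris']

# Fixed formatting: one prompt string per example in the batch.
fixed = [
    f"### Question: {q}\n ### Answer: {a}"
    for q, a in zip(batch["question"], batch["answer"])
]
print(fixed[0])
# ### Question: What is 2+2?
#  ### Answer: 4
```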


 trainer = SFTTrainer(
     model,
@@ -76,6 +79,7 @@ trainer = SFTTrainer(

 trainer.train()
 ```
+To properly format your input, make sure to process all the examples by looping over them and returning a list of processed texts. Check out a full example of how to use the SFTTrainer on the alpaca dataset [here](https://github.com/lvwerra/trl/pull/444#issue-1760952763)

 ### Packing dataset ([`ConstantLengthDataset`])

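For reference, a minimal end-to-end sketch of the corrected non-packed usage; the model id and the tiny inline dataset are illustrative assumptions, not taken from this PR:

```python
# Sketch only: model id and dataset below are placeholders.
from datasets import Dataset
from trl import SFTTrainer

train_dataset = Dataset.from_dict(
    {
        "question": ["What is 2+2?", "What is the capital of France?"],
        "answer": ["4", "Paris"],
    }
)

def formatting_prompts_func(example):
    # Called with a batch (dict of lists), so return one string per example.
    output_texts = []
    for i in range(len(example["question"])):
        output_texts.append(
            f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
        )
    return output_texts

trainer = SFTTrainer(
    "facebook/opt-350m",  # assumed model id for this sketch
    train_dataset=train_dataset,
    formatting_func=formatting_prompts_func,
    packing=False,
)
trainer.train()
```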
29 changes: 20 additions & 9 deletions tests/test_sft_trainer.py
@@ -30,6 +30,14 @@ def formatting_prompts_func(example):
     return text


+def formatting_prompts_func_batched(example):
+    output_text = []
+    for i, question in enumerate(example["question"]):
+        text = f"### Question: {question}\n ### Answer: {example['answer'][i]}"
+        output_text.append(text)
+    return output_text
+
+
 if is_peft_available():
     from peft import LoraConfig, PeftModel

@@ -170,12 +178,22 @@ def test_sft_trainer_uncorrect_data(self):
                 packing=True,
             )

-        # This should work as well
+        # This should not work either
+        with self.assertRaises(ValueError):
+            _ = SFTTrainer(
+                model=self.model,
+                args=training_args,
+                train_dataset=self.dummy_dataset,
+                formatting_func=formatting_prompts_func,
+                packing=False,
+            )
+
+        # but this should work
         _ = SFTTrainer(
             model=self.model,
             args=training_args,
             train_dataset=self.dummy_dataset,
-            formatting_func=formatting_prompts_func,
+            formatting_func=formatting_prompts_func_batched,
             packing=False,
         )

@@ -350,13 +368,6 @@ def test_sft_trainer_with_model(self):
             per_device_train_batch_size=2,
         )

-        def formatting_prompts_func_batched(example):
-            output_text = []
-            for i, question in enumerate(example["question"]):
-                text = f"### Question: {question}\n ### Answer: {example['answer'][i]}"
-                output_text.append(text)
-            return output_text
-
         trainer = SFTTrainer(
             model=self.model,
             args=training_args,
Expand Down
22 changes: 19 additions & 3 deletions trl/trainer/sft_trainer.py
@@ -276,27 +276,43 @@ def _prepare_non_packed_dataloader(
         self, tokenizer, dataset, dataset_text_field, max_seq_len, formatting_func=None
     ):
         use_formatting_func = formatting_func is not None and dataset_text_field is None
+        self._dataset_sanity_checked = False

         # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
         def tokenize(element):
+            input_batch = []
+            attention_masks = []
+
             outputs = tokenizer(
                 element[dataset_text_field] if not use_formatting_func else formatting_func(element),
                 truncation=True,
+                padding=True,

Contributor:
@younesbelkada this code is still incorrect - consider the case where all samples in the dataset are less than `max_seq_len`. Each batch will be padded to the largest element in the batch, but no data will pass the `if length == max_seq_len` check below.

Perhaps:

                padding='max_length',

Contributor Author:

Yes, you are correct, thanks a lot for flagging! Do you want to open a PR for that? Happy to do it otherwise.

                 max_length=max_seq_len,
                 return_overflowing_tokens=False,
                 return_length=True,
             )
-            input_batch = []
-            for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
+
+            if use_formatting_func and not self._dataset_sanity_checked:
+                if not isinstance(formatting_func(element), list):
+                    raise ValueError(
+                        "The `formatting_func` should return a list of processed strings since it can lead to silent bugs."
+                    )
+                else:
+                    self._dataset_sanity_checked = True
+
+            for length, input_ids, attention_mask in zip(
+                outputs["length"], outputs["input_ids"], outputs["attention_mask"]
+            ):
                 if length == max_seq_len:
                     input_batch.append(input_ids)
+                    attention_masks.append(attention_mask)

             if len(input_batch) == 0:
                 # warn users
                 warnings.warn(
                     f"Found 0 samples with a length of {max_seq_len}. You might want to decrease the `max_seq_len` argument."
                 )
-            return {"input_ids": input_batch}
+            return {"input_ids": input_batch, "attention_mask": attention_masks}

         tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)

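To see the failure mode flagged in the review thread above, a small self-contained sketch; the `gpt2` tokenizer and toy texts are assumptions for the demo, not part of this PR:

```python
# Sketch of the padding issue: with padding=True a batch is only padded to
# its longest member, so when every sample is shorter than max_seq_len the
# `length == max_seq_len` filter drops everything. padding="max_length"
# pads all the way up and keeps the samples.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
max_seq_len = 16
texts = ["a short example", "another short example"]  # all under 16 tokens

for padding in (True, "max_length"):
    outputs = tokenizer(
        texts,
        truncation=True,
        padding=padding,
        max_length=max_seq_len,
        return_length=True,
    )
    kept = [
        ids
        for length, ids in zip(outputs["length"], outputs["input_ids"])
        if length == max_seq_len
    ]
    print(f"padding={padding!r}: kept {len(kept)} of {len(texts)} samples")
# padding=True: kept 0 of 2 samples
# padding='max_length': kept 2 of 2 samples
```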