[SFTTrainer] Fix non packed dataset #444
@@ -276,27 +276,43 @@ def _prepare_non_packed_dataloader(

The method after this change:

    def _prepare_non_packed_dataloader(
        self, tokenizer, dataset, dataset_text_field, max_seq_len, formatting_func=None
    ):
        use_formatting_func = formatting_func is not None and dataset_text_field is None
        self._dataset_sanity_checked = False

        # Inspired from: https://huggingface.co/learn/nlp-course/chapter7/6?fw=pt
        def tokenize(element):
            input_batch = []
            attention_masks = []

            outputs = tokenizer(
                element[dataset_text_field] if not use_formatting_func else formatting_func(element),
                truncation=True,
                padding=True,
                max_length=max_seq_len,
                return_overflowing_tokens=False,
                return_length=True,
            )

            if use_formatting_func and not self._dataset_sanity_checked:
                if not isinstance(formatting_func(element), list):
                    raise ValueError(
                        "The `formatting_func` should return a list of processed strings since it can lead to silent bugs."
                    )
                else:
                    self._dataset_sanity_checked = True

            for length, input_ids, attention_mask in zip(
                outputs["length"], outputs["input_ids"], outputs["attention_mask"]
            ):
                if length == max_seq_len:
                    input_batch.append(input_ids)
                    attention_masks.append(attention_mask)

            if len(input_batch) == 0:
                # warn users
                warnings.warn(
                    f"Found 0 samples with a length of {max_seq_len}. You might want to decrease the `max_seq_len` argument."
                )

            return {"input_ids": input_batch, "attention_mask": attention_masks}

        tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)

(Previously the loop iterated only over outputs["length"] and outputs["input_ids"], and the method returned {"input_ids": input_batch} without attention masks.)

Review comment on the tokenization logic above:

@younesbelkada this code is still incorrect. Consider the case where all samples in the dataset are shorter than `max_seq_len`: with `padding=True` the batch is only padded to its longest sample, so the `length == max_seq_len` filter keeps nothing. Perhaps: (suggested change not shown in this view)

Reply from younesbelkada:

Yes, you are correct, thanks a lot for flagging. Do you want to open a PR for that? Happy to do it otherwise.
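A minimal, self-contained sketch of the failure mode the reviewer describes. The `toy_tokenize` helper and its character-level "tokens" are stand-ins invented for illustration; they mimic the real tokenizer call with `truncation=True`, `padding=True`, `max_length=max_seq_len`, `return_length=True`, but are not part of TRL or transformers:

```python
# Hypothetical stand-in for the Hugging Face tokenizer call used in tokenize().
def toy_tokenize(texts, max_seq_len):
    ids = [[ord(c) for c in t][:max_seq_len] for t in texts]   # truncation
    longest = max(len(x) for x in ids)                         # padding to the longest sample in the batch
    input_ids = [x + [0] * (longest - len(x)) for x in ids]
    attention_mask = [[1] * len(x) + [0] * (longest - len(x)) for x in ids]
    length = [len(x) for x in input_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "length": length}

max_seq_len = 512
outputs = toy_tokenize(["short sample", "another short sample"], max_seq_len)

# The PR's filter keeps only samples whose padded length equals max_seq_len.
# If every sample is shorter, the batch comes out empty and the
# "Found 0 samples" warning fires.
kept = [ids for l, ids in zip(outputs["length"], outputs["input_ids"]) if l == max_seq_len]
print(len(kept))  # 0

# One possible fix (an assumption, not necessarily the follow-up that was merged):
# keep samples at or below max_seq_len instead of requiring equality.
kept_fixed = [ids for l, ids in zip(outputs["length"], outputs["input_ids"]) if l <= max_seq_len]
print(len(kept_fixed))  # 2
```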
A follow-up review comment:

Oh interesting. So previously, we were dumping an entire dataset into the prompt?

Reply from younesbelkada:

Exactly, yes :D The previous examples in the documentation were wrong, and we were dumping entire mini-batches into the prompt when processing the dataset.. :/
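For context, the new sanity check in the diff expects `formatting_func` to return a list of formatted strings, one per example in the batched `map` call, rather than a single concatenated string. Below is a hedged sketch of a compliant function; the field names `instruction` and `output` are placeholders, not a real dataset schema:

```python
# Hedged sketch: a formatting_func that passes the new isinstance(..., list) check.
# It receives a batch of examples (a dict of lists) and returns one formatted
# string per example, never a single blob built from the whole batch.
def formatting_func(examples):
    texts = []
    for instruction, output in zip(examples["instruction"], examples["output"]):
        texts.append(f"### Question: {instruction}\n### Answer: {output}")
    return texts
```

Returning a single string here is what caused whole mini-batches to be dumped into one prompt, which is exactly the silent bug the new ValueError guards against.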