[SFTTrainer] Fix non packed dataset #444

Merged: 3 commits merged into main from fix-sft-dataset on Jun 16, 2023

Conversation

@younesbelkada (Contributor) commented Jun 16, 2023

What does this PR do?

This PR properly educates users on how to correctly use the formatting_func argument when training on a non-packed dataset.
Since the dataset processing calls dataset.map(xxx, batched=True) under the hood, the formatting function must return an array of processed texts so that every text in the example batch is handled; otherwise it leads to silent, hard-to-understand bugs such as the one described in #439:

from datasets import load_dataset
from trl import SFTTrainer
import transformers

dataset = load_dataset("tatsu-lab/alpaca", split="train")

model = transformers.AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-350m")

def formatting_prompts_func(examples):
    # Must return a list of formatted texts, one per example in the batch,
    # because the dataset is mapped with batched=True.
    output_text = []
    for i in range(len(examples["instruction"])):
        instruction = examples["instruction"][i]
        input_text = examples["input"][i]
        response = examples["output"][i]

        if len(input_text) >= 2:
            text = f'''Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input_text}

### Response:
{response}
'''
        else:
            text = f'''Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
{response}
'''
        output_text.append(text)

    return output_text

trainer = SFTTrainer(
    model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    max_seq_length=256,
    packing=False,
)

trainer.train()

The PR adds a sanity check when processing the dataset and adds the argument padding=True so that sequences of length max_seq_len are always returned; it also correctly appends the attention mask to the output dataset.
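
For context, a rough sketch of what the described mapping step looks like conceptually; the names use_formatting_func and dataset_text_field follow the snippet quoted later in this thread, and the rest is illustrative rather than the exact PR diff:

def tokenize(element):
    # Sanity check: with dataset.map(..., batched=True), the formatting
    # function must return a list of strings, one per example in the batch.
    if use_formatting_func and not isinstance(formatting_func(element), list):
        raise ValueError(
            "The formatting_func should return a list of processed strings "
            "when packing=False, since the dataset is mapped in batches."
        )

    outputs = tokenizer(
        element[dataset_text_field] if not use_formatting_func else formatting_func(element),
        truncation=True,
        padding=True,
        max_length=max_seq_len,
    )
    return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}

tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)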

@HuggingFaceDocBuilderDev commented Jun 16, 2023

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada merged commit d1ad540 into main on Jun 16, 2023
@younesbelkada deleted the fix-sft-dataset branch on Jun 16, 2023 at 16:51
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Comment on lines +68 to +72
output_texts = []
for i in range(len(example['question'])):
    text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
    output_texts.append(text)
return output_texts
Contributor

Oh interesting. So previously, we were dumping an entire dataset to the prompt?

Contributor Author (@younesbelkada)

Exactly, yes :D The previous examples in the documentation were wrong, and we were dumping entire mini-batches into the prompt when processing the dataset. :/

outputs = tokenizer(
    element[dataset_text_field] if not use_formatting_func else formatting_func(element),
    truncation=True,
    padding=True,


@younesbelkada this code is still incorrect - consider the case where all samples in the dataset are shorter than max_seq_len. Each batch will be padded to the longest element in the batch, but no data will pass the length == max_seq_len check below.

Perhaps:

    padding='max_length',
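
A quick toy illustration of the difference between the two padding modes (not part of the original comment; the checkpoint and the max length of 256 are only for demonstration):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
batch = ["short example", "a slightly longer example"]

# padding=True pads only to the longest sequence in the batch, so every
# sample stays well below max_seq_len and fails a length == max_seq_len check.
dynamic = tokenizer(batch, truncation=True, padding=True, max_length=256)

# padding="max_length" always pads to max_length, so the check passes.
fixed = tokenizer(batch, truncation=True, padding="max_length", max_length=256)

print(len(dynamic["input_ids"][0]))  # a handful of tokens
print(len(fixed["input_ids"][0]))    # 256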

Contributor Author (@younesbelkada)

Yes, you are correct, thanks a lot for flagging. Do you want to open a PR for that? Happy to do it otherwise.

@ahmadmustafaanis

Will this fine-tune the complete model end-to-end, or will this example fine-tune just a portion of it, as in LoRA?

@lvwerra (Member) commented Aug 2, 2023

The above example will train the full model, but there are also options to use LoRA.
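
For completeness, a minimal sketch of the LoRA option (hyperparameters are illustrative; it reuses the model, tokenizer, dataset, and formatting function from the example above):

from peft import LoraConfig

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    max_seq_length=256,
    packing=False,
    peft_config=peft_config,  # only the LoRA adapter weights are trained
)

trainer.train()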

@hy-chen commented Jan 18, 2024

Was this merged into main and then reverted? Padding is still False in _prepare_non_packed_dataloader on main (0.7.9).

@hy-chen commented Jan 18, 2024

Actually, padding was turned off by this PR: #512

Now running SFT on alpaca gives:
ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (labels in this case) have excessive nesting (inputs type list where type int is expected).

@haochuan-li

Any update on this issue?
