-
Notifications
You must be signed in to change notification settings - Fork 27.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix FA2 integration #28142
Fix FA2 integration #28142
Conversation
Co-Authored-By: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for the deep dive! As this is / was critical, this could be added to the Llama.md as a tip ? (nit)
Otherwise looks great. autocast
feature, was introduced in PyTorch version 1.6.0 so no worries there
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense as discussed offline ! Thanks very much for the deep dive @pacman100 !
Done. |
So FSDP is saved? |
I think so, from the experiment @pacman100 shared with me you could load a transformers model with FA-2 and train it with autocast ( |
Hello @teknium1, to re-confirm, I ran the below experiment on 8 80GB GPUs to finetune Mistral 7B for the SFT task on Ultrachat 200K (1 epoch). Code: https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/run_fsdp.sh
Observations: |
* fix fa2 * fix FA2 for popular models * improve warning and add Younes as co-author Co-Authored-By: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix the warning * Add Tip * typo fix * nit --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Hi! I haven't been able to use Llama3-8B w/FA2. I'm running the following code: from transformers import Trainer, TrainingArguments, LlamaForCausalLM
from typing import Dict
import numpy as np
from torch.utils.data import Dataset
import torch
class DummyDataset(Dataset):
def __init__(self, num_samples: int, sequence_length: int) -> None:
self.num_samples = num_samples
self.sequence_length = sequence_length
def __len__(self) -> int:
return self.num_samples
def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
x = torch.LongTensor(np.random.randint(low= 0, high= 1000, size=(self.sequence_length+ 1)))
return {"input_ids": x[:-1], "labels": x[1:]}
def main():
path_to_model = "/mloscratch/homes/solergib/models/Meta-Llama-3-8B-Instruct"
output_dir = "/mloscratch/homes/solergib/simple/output"
training_arguments = TrainingArguments(per_device_train_batch_size=1, gradient_checkpointing=True, bf16=True, max_steps=10, output_dir=output_dir)
model = LlamaForCausalLM.from_pretrained(path_to_model, attn_implementation="flash_attention_2") # It's the default
train_dataset = DummyDataset(10000000, 1024)
trainer = Trainer(model=model, args=training_arguments, train_dataset=train_dataset)
trainer.train()
if __name__ == "__main__":
main() And I get the same error complaining about the dtype:
My env:
Thanks! |
Hey! This is unrelated to |
What does this PR do?
Issues with the current FA2 integration.
torch_dtype
to thefrom_pretrained
class method mandatory. This leads to the whole model being loaded in half-precision which leads to unstable training because it would result in pure half precision training instead of mixed-precision training. Please refer Mistral loss instability #26498 (comment) for more details.Currently, main branch throws below error when not passing half precision to
torch_dtype
which shouldn't be the case.torch_dtype
, then recast the model to float32 and try to train but then end up getting error from Flash Attention library as given below:All these issues are being resolved by this PR. Notice the above graph with the before and after PR logs. With this PR, the loss is similar to the case when not using FA2.