OOM error on 7B model despite 4-bit quantization with 24GB VRAM #846

Closed
attkap opened this issue Oct 8, 2023 · 3 comments
attkap commented Oct 8, 2023

I've been trying to follow beginner tutorials like
https://www.philschmid.de/instruction-tune-llama-2
and
https://medium.com/@qendelai/fine-tuning-mistral-7b-instruct-model-in-colab-a-beginners-guide-0f7bebccf11c (and many others), but I keep getting OOM errors despite optimising heavily for memory. What's strange is that I hit OOM errors even when I follow the standard tutorials with exactly the same setup (an AWS EC2 g5.4xlarge) as in the tutorial.

I've now finally got a training run to work with an A10 (24GB VRAM) and the script below.
Still, this is taking up a full 16GB of VRAM and running very slowly.

I'm using a per_device_train_batch_size of 2, with 4-bit quantised LoRA on reduced target modules and a max seq length of 2048. Can anyone explain to me why this setup is still so memory inefficient?

Any advice would be deeply appreciated🙏 - I've been banging my head against this for a full day.

Code to Reproduce

import torch 
from trl import SFTTrainer
from datasets import load_dataset 
from transformers import TrainingArguments, AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM, LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Loading the dataset 
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

# Filter QA pairs without context
dataset = dataset.filter(lambda x:x['context'] == '')

# A prompting formatting function 
def create_prompt_instruction(sample):
    return f"""### Instruction: 
    Use the input below to create an instruction, which could have been used to generate the input using an LLM. 

    ### Input 
    {sample['response']}

    ### Response:
    {sample['instruction']}
    """
# Import model and tokenizer 
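# Note: passing load_in_4bit=True directly builds a default BitsAndBytesConfig under the hood;
# an explicit config is sketched after this script.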
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device_map='auto', load_in_4bit=True, use_cache=False)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# PEFT Config 
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    r=64,
    bias="none",
    task_type="CAUSAL_LM"
)

# Prepare the model for finetuning 
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)

# Split dataset into 70% for training and 30% for testing 
dataset = dataset.train_test_split(test_size=0.3)

# Use the train split for training 
train_dataset = dataset['train']

# Define training arguments 
args = TrainingArguments(
    output_dir = "mistral_instruct_qa",
    num_train_epochs = 5,
    per_device_train_batch_size = 2,
    warmup_ratio = 0.03,  # warmup_steps expects an integer step count; a 3% fraction belongs in warmup_ratio
    logging_steps=10,
    save_strategy="epoch",
    learning_rate=2e-4,
    bf16=True,
    lr_scheduler_type='constant',
    disable_tqdm=True
)

# Define SFTTrainer arguments 
max_seq_length = 2048

trainer = SFTTrainer(
    model=model,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    formatting_func=create_prompt_instruction,
    args=args,
    train_dataset=train_dataset,
)

# Kick off the finetuning job 
trainer.train()

# Save finetuned model 
trainer.save_model("mistral_instruct_qa")
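
For reference, the load above passes load_in_4bit=True directly, so the default 4-bit settings are used and the imported BitsAndBytesConfig goes unused. An explicit NF4 config with a bfloat16 compute dtype would look roughly like the sketch below; the specific quantization choices (NF4, double quantization, bfloat16 compute) are illustrative assumptions, not benchmarked on this setup.

# Untested sketch: an explicit 4-bit config in place of the bare load_in_4bit=True above.
# The NF4 / double-quantization / bfloat16 choices are assumptions for illustration.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    device_map="auto",
    quantization_config=bnb_config,
    use_cache=False,
)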
lvwerra (Member) commented Oct 10, 2023

What about trying a machine with more VRAM, e.g. an A100 (80GB)? It would give you more headroom and let you train the model faster. Maybe @philschmid has some insight into what's different from the tutorial.

younesbelkada (Contributor) commented

Hi @attkap

Since the introduction of #728 you need to manually set gradient_checkpointing=True in TrainingArguments in order to run more memory-efficient training.
I believe this should resolve your issue and you'll be able to run training smoothly.
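
For concreteness, a minimal sketch of that change applied to the script above; gradient_checkpointing=True is the suggested fix, and the remaining arguments just mirror the original setup:

# Same TrainingArguments as in the reporter's script, with gradient checkpointing enabled.
args = TrainingArguments(
    output_dir="mistral_instruct_qa",
    num_train_epochs=5,
    per_device_train_batch_size=2,
    gradient_checkpointing=True,  # trade extra compute for lower activation memory
    warmup_ratio=0.03,
    logging_steps=10,
    save_strategy="epoch",
    learning_rate=2e-4,
    bf16=True,
    lr_scheduler_type="constant",
    disable_tqdm=True,
)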

attkap (Author) commented Oct 27, 2023

Hi all, thanks for the responses. Indeed, the issue was resolved by setting gradient_checkpointing=True. Thanks for the help!

attkap closed this as completed Oct 27, 2023