I've been trying to follow beginner's tutorials such as https://www.philschmid.de/instruction-tune-llama-2 and https://medium.com/@qendelai/fine-tuning-mistral-7b-instruct-model-in-colab-a-beginners-guide-0f7bebccf11c (and many others), but I keep getting OOM errors despite optimising heavily for memory. What is strange is that I get OOM errors even with exactly the same setup (an AWS EC2 g5.4xlarge) as in the tutorials.
I've now finally got a training run to work on an A10 (24GB VRAM) with the script below.
Still, it is taking up a full 16GB of VRAM and running very slowly.
I'm using a per_device_train_batch_size of 2, 4-bit quantised LoRA with a reduced set of target modules, and a max sequence length of 2048. Can anyone explain why this setup is still so memory inefficient?
Any advice would be deeply appreciated 🙏 - I've been banging my head against this for a full day.
Code to Reproduce
import torch
from trl import SFTTrainer
from datasets import load_dataset
from transformers import TrainingArguments, AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM, LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Loading the dataset
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

# Filter QA pairs without context
dataset = dataset.filter(lambda x: x['context'] == '')

# A prompt formatting function
def create_prompt_instruction(sample):
    return f"""### Instruction:
Use the input below to create an instruction, which could have been used to generate the input using an LLM.

### Input
{sample['response']}

### Response:
{sample['instruction']}
"""

# Import model and tokenizer
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device_map='auto', load_in_4bit=True, use_cache=False)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# PEFT Config
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
r=64,
bias="none",
task_type="CAUSAL_LM"
)
# Prepare the model for finetuning
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)

# Split dataset into 70% for training and 30% for testing
dataset = dataset.train_test_split(test_size=0.3)

# Use the train split for training
train_dataset = dataset['train']

# Define training arguments
args = TrainingArguments(
output_dir="mistral_instruct_qa",
num_train_epochs=5,
per_device_train_batch_size=2,
warmup_steps=0.03,
logging_steps=10,
save_strategy="epoch",
learning_rate=2e-4,
bf16=True,
lr_scheduler_type='constant',
disable_tqdm=True
)
# Define SFTTrainer arguments
max_seq_length = 2048

trainer = SFTTrainer(
model=model,
peft_config=peft_config,
max_seq_length=max_seq_length,
tokenizer=tokenizer,
packing=True,
formatting_func=create_prompt_instruction,
args=args,
train_dataset=train_dataset,
)
# Kick off the finetuning job
trainer.train()

# Save finetuned model
trainer.save_model("mistral_instruct_qa")
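Side note: BitsAndBytesConfig is imported in the script above but never used. Below is a minimal sketch of what an explicit 4-bit quantisation config could look like in place of the bare load_in_4bit=True flag; the nf4 / double-quant / bfloat16-compute settings are assumptions typical of QLoRA setups, not something the script above actually ran, and it reuses the imports already at the top of the script.

# Sketch only: explicit 4-bit NF4 config instead of load_in_4bit=True
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # assumed NF4 quantisation
    bnb_4bit_use_double_quant=True,         # assumed double quantisation
    bnb_4bit_compute_dtype=torch.bfloat16   # assumed bf16 compute dtype
)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    device_map='auto',
    quantization_config=bnb_config,
    use_cache=False
)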
What about trying a machine with more VRAM, e.g. an A100 (80GB)? That should give you more headroom and let you train faster. Maybe @philschmid has some insight into what's different from the tutorial.
Since the introduction of #728 you need to manually set gradient_checkpointing=True in the TrainingArguments in order to run more memory-efficient training.
I believe this should resolve your issue and you'll manage to run training smoothly.
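For concreteness, a minimal sketch of that change, reusing the TrainingArguments from the script above; only gradient_checkpointing is new, everything else is unchanged:

args = TrainingArguments(
    output_dir="mistral_instruct_qa",
    num_train_epochs=5,
    per_device_train_batch_size=2,
    gradient_checkpointing=True,  # recompute activations in the backward pass to save memory
    warmup_steps=0.03,
    logging_steps=10,
    save_strategy="epoch",
    learning_rate=2e-4,
    bf16=True,
    lr_scheduler_type='constant',
    disable_tqdm=True
)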