Vision finetuning
danielhanchen released this on 21 Nov 17:55
- We support Llama 3.2 Vision 11B, 90B; Pixtral; Qwen2VL 2B, 7B, 72B; and any Llava variant, such as Llava NeXT!
- We support 16bit LoRA or 4bit QLoRA. Both are accelerated and use much less memory!
- Llama 3.2 Vision finetuning - Radiography use case. Free Colab and Kaggle notebooks.
- Qwen 2 VL Vision finetuning - Maths OCR to LaTeX. Free Colab and Kaggle notebooks.
- Pixtral 12B Vision finetuning - General QA datasets. Free Colab notebook.
- Please run: pip install --upgrade --no-cache-dir unsloth unsloth_zoo
from unsloth import FastVisionModel # NEW instead of FastLanguageModel
import torch

model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/Llama-3.2-11B-Vision-Instruct",
    load_in_4bit = True, # Use 4bit quantization to reduce memory usage. Can be False.
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
)

model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers = True, # False if not finetuning vision part
    finetune_language_layers = True, # False if not finetuning language part
    finetune_attention_modules = True, # False if not finetuning attention layers
    finetune_mlp_modules = True, # False if not finetuning MLP layers
    r = 16, # The larger, the higher the accuracy, but might overfit
    lora_alpha = 16, # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False, # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
    # target_modules = "all-linear", # Optional now! Can specify a list if needed
)

from datasets import load_dataset
dataset = load_dataset("unsloth/llava-instruct-mix-vsft-mini", split = "train")

from unsloth import is_bf16_supported
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig

FastVisionModel.for_training(model) # Enable for training!

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    data_collator = UnslothVisionDataCollator(model, tokenizer), # Must use!
    train_dataset = dataset,
    args = SFTConfig(
        per_device_train_batch_size = 1, # Reduce to 1 to make Pixtral fit!
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 30,
        # num_train_epochs = 1, # Set this instead of max_steps for full training runs
        learning_rate = 2e-4,
        fp16 = not is_bf16_supported(),
        bf16 = is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none", # Use "wandb" to log to Weights and Biases
        # You MUST put the below items for vision finetuning:
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        dataset_num_proc = 4,
        max_seq_length = 2048,
    ),
)

trainer_stats = trainer.train()
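As a rough check on the training configuration above, the effective batch size and the learning-rate schedule can be worked out by hand. The sketch below is plain Python, assumes a single GPU, and uses a simplified "linear warmup, then linear decay to zero" formula. The helper name `linear_lr` is illustrative and not part of Unsloth or TRL.

```python
# Values copied from the SFTConfig in the training example.
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
warmup_steps = 5
max_steps = 30
learning_rate = 2e-4

# Assumption: one GPU, so there is no extra factor for the device count.
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
samples_seen = effective_batch_size * max_steps


def linear_lr(step: int) -> float:
    """Simplified linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return learning_rate * step / warmup_steps
    remaining = max_steps - step
    return learning_rate * max(0.0, remaining / (max_steps - warmup_steps))


print(effective_batch_size)  # 4
print(samples_seen)          # 120
print(linear_lr(warmup_steps))  # peak learning rate, reached at the end of warmup
```

With `max_steps = 30` the run touches only 120 samples, so it works as a quick smoke test. For a full run, switch to `num_train_epochs` as the commented-out line in the config suggests.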
After finetuning, you can also do inference:
FastVisionModel.for_inference(model) # Enable for inference!

image = dataset[2]["images"][0]
instruction = "Is there something interesting about this image?"

messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": instruction}
    ]}
]
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt = True)
inputs = tokenizer(
    image,
    input_text,
    add_special_tokens = False,
    return_tensors = "pt",
).to("cuda")

from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer, skip_prompt = True)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,
                   use_cache = True, temperature = 1.5, min_p = 0.1)
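The `temperature = 1.5, min_p = 0.1` pair in the generate call combines a high temperature with min-p filtering. As a rough illustration of the idea (not the transformers implementation), the sketch below applies the temperature to a toy list of logits, then drops every token whose probability falls below `min_p` times the top probability and renormalizes. The function names and the toy logits are made up for this example.

```python
import math


def softmax(logits, temperature=1.0):
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def min_p_filter(probs, min_p):
    """Zero out tokens below min_p * max(probs), then renormalize."""
    threshold = min_p * max(probs)
    kept = [p if p >= threshold else 0.0 for p in probs]
    total = sum(kept)
    return [p / total for p in kept]


toy_logits = [5.0, 4.0, 1.0, -2.0]  # hypothetical 4-token vocabulary
probs = softmax(toy_logits, temperature=1.5)
filtered = min_p_filter(probs, min_p=0.1)
print([round(p, 3) for p in filtered])
```

In this toy case the two low-probability tokens are removed even though the high temperature flattened the distribution, which is the usual motivation for pairing the two settings.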
We also support merging QLoRA / LoRA directly into 16bit weights for serving:
# Select ONLY 1 to save! (Both not needed!)
# Save locally to 16bit
if False: model.save_pretrained_merged("unsloth_finetune", tokenizer,)
# To export and save to your Hugging Face account
if False: model.push_to_hub_merged("YOUR_USERNAME/unsloth_finetune", tokenizer, token = "PUT_HERE")
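Conceptually, merging folds the low-rank LoRA update back into each base weight matrix: W' = W + s * (B @ A), where s is `lora_alpha / r` for standard LoRA and `lora_alpha / sqrt(r)` for rank-stabilized LoRA. The sketch below shows that arithmetic on tiny pure-Python matrices. It only illustrates the idea and is not how `save_pretrained_merged` is implemented; the helper names are hypothetical.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_scale(lora_alpha, r, use_rslora=False):
    """Scaling factor: alpha / r, or alpha / sqrt(r) for rank-stabilized LoRA."""
    return lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r


def merge_lora(weight, lora_b, lora_a, scale):
    """Return weight + scale * (lora_b @ lora_a) as a new matrix."""
    delta = matmul(lora_b, lora_a)
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(weight, delta)]


# Toy 2x2 base weight with a rank-1 adapter (B is 2x1, A is 1x2).
weight = [[1.0, 0.0], [0.0, 1.0]]
lora_b = [[1.0], [2.0]]
lora_a = [[0.5, -0.5]]

merged = merge_lora(weight, lora_b, lora_a, lora_scale(lora_alpha=1, r=1))
print(merged)  # [[1.5, -0.5], [1.0, 0.0]]

# Scaling factors for the r = 16, lora_alpha = 16 settings used in the training example.
print(lora_scale(16, 16))                   # 1.0
print(lora_scale(16, 16, use_rslora=True))  # 4.0
```

After merging, the adapter matrices are no longer needed at inference time, which is why a merged 16bit checkpoint is convenient for serving.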
What's Changed
- Llama 3.2 by @danielhanchen in #1058
- Fix merges by @danielhanchen in #1079
- Handle absolute paths for save_to_gguf using pathlib by @giuliabaldini in #1120
- Only remove folder in sentencepiece check if it was created by @giuliabaldini in #1121
- Gradient Accumulation Fix by @danielhanchen in #1134
- Gradient Accumulation Fix by @danielhanchen in #1146
- fix: compute_loss bug by @vo1d-ai in #1151
- Windows installation guide in README by @timothelaborie in #1165
- chore: update chat_templates.py by @eltociear in #1166
- Many bug fixes by @danielhanchen in #1162
- Fix/patch tokenizer by @Erland366 in #1171
- Fix DPO, ORPO by @danielhanchen in #1177
- fix/transformers-unpack by @Erland366 in #1180
- Fix 4.47 issue by @danielhanchen in #1182
- 25% less mem and 10% faster training: Do not upcast lm_head and embedding to float32 by @Datta0 in #1186
- Cleanup upcast logs by @Datta0 in #1188
- Fix/phi-longrope by @Erland366 in #1193
- Bug fixes by @danielhanchen in #1195
- Fix/casting continue pretraining by @Erland366 in #1200
- Feat/all tmp by @danielhanchen in #1219
- Bug fixes by @danielhanchen in #1245
- Bug fix by @danielhanchen in #1249
- Bug fixes by @danielhanchen in #1255
- Fix: cast logits to float32 in cross_entropy_forward to prevent errors by @Erland366 in #1254
- Throw error when inferencing longer than max_position_embeddings by @Datta0 in #1236
- CLI now handles user input strings for dtype correctly by @Rabbidon in #1235
- Bug fixes by @danielhanchen in #1259
- Qwen 2.5 by @danielhanchen in #1280
- Fix/export mistral by @Erland366 in #1281
- DOC Update - Update README.md with os.environ in example by @udaygirish in #1269
- fix/get_chat_template by @Erland366 in #1246
- fix/sft-trainer by @Erland366 in #1276
- Bug fixes by @danielhanchen in #1288
- fix/sfttrainer-compatibility by @Erland366 in #1293
New Contributors
- @giuliabaldini made their first contribution in #1120
- @vo1d-ai made their first contribution in #1151
- @timothelaborie made their first contribution in #1165
- @eltociear made their first contribution in #1166
- @Erland366 made their first contribution in #1171
- @Datta0 made their first contribution in #1186
- @Rabbidon made their first contribution in #1235
- @udaygirish made their first contribution in #1269
Full Changelog: September-2024...November-2024