diff --git a/examples/scripts/sft_vlm_smol_vlm.py b/examples/scripts/sft_vlm_smol_vlm.py
new file mode 100644
index 0000000000..2cac4f2cac
--- /dev/null
+++ b/examples/scripts/sft_vlm_smol_vlm.py
@@ -0,0 +1,145 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+pip install pillow
+
+# Tested on 8x H100 GPUs
+accelerate launch \
+    --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
+    examples/scripts/sft_vlm_smol_vlm.py \
+    --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
+    --model_name_or_path HuggingFaceTB/SmolVLM-Instruct \
+    --per_device_train_batch_size 1 \
+    --gradient_accumulation_steps 1 \
+    --output_dir sft-smol-vlm-hf \
+    --bf16 \
+    --torch_dtype bfloat16 \
+    --gradient_checkpointing \
+    --use_peft \
+    --lora_target_modules down_proj o_proj k_proj q_proj gate_proj up_proj v_proj
+
+For LLaVA-NeXT, use: (requires transformers>=4.45)
+    --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf
+
+For meta-llama/Llama-3.2-11B-Vision-Instruct, use: (requires transformers>=4.45.1)
+    --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct
+"""
+
+import torch
+from datasets import load_dataset
+from transformers import (
+    AutoModelForVision2Seq,
+    AutoProcessor,
+    Idefics3ForConditionalGeneration,
+    LlavaForConditionalGeneration,
+)
+
+from trl import (
+    ModelConfig,
+    ScriptArguments,
+    SFTConfig,
+    SFTTrainer,
+    TrlParser,
+    get_kbit_device_map,
+    get_peft_config,
+    get_quantization_config,
+)
+
+
+if __name__ == "__main__":
+    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
+    script_args, training_args, model_config = parser.parse_args_and_config()
+    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
+    training_args.remove_unused_columns = False
+    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
+
+    ################
+    # Model, Tokenizer & Processor
+    ################
+    torch_dtype = (
+        model_config.torch_dtype
+        if model_config.torch_dtype in ["auto", None]
+        else getattr(torch, model_config.torch_dtype)
+    )
+    quantization_config = get_quantization_config(model_config)
+    model_kwargs = dict(
+        revision=model_config.model_revision,
+        attn_implementation=model_config.attn_implementation,
+        torch_dtype=torch_dtype,
+        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        quantization_config=quantization_config,
+    )
+    processor = AutoProcessor.from_pretrained(
+        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
+    )
+
+    model = AutoModelForVision2Seq.from_pretrained(
+        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
+    )
+
+    ################
+    # Create a data collator to encode text and image pairs
+    ################
+    def collate_fn(examples):
+        # Get the texts and images, and apply the chat template
+        texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
+        images = [example["images"] for example in examples]
+        if isinstance(model, LlavaForConditionalGeneration):
+            # LLaVA 1.5 does not support multiple images
+            images = [image[0] for image in images]
+
+        # Tokenize the texts and process the images
+        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+
+        # The labels are the input_ids, and we mask the padding tokens in the loss computation
+        labels = batch["input_ids"].clone()
+        labels[labels == processor.tokenizer.pad_token_id] = -100
+        # Ignore the image token index in the loss computation (model specific)
+        if isinstance(model, Idefics3ForConditionalGeneration):
+            image_token_id = processor.tokenizer.additional_special_tokens_ids[
+                processor.tokenizer.additional_special_tokens.index("<image>")
+            ]
+        else:
+            image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
+        labels[labels == image_token_id] = -100
+        batch["labels"] = labels
+
+        return batch
+
+    ################
+    # Dataset
+    ################
+    dataset = load_dataset(script_args.dataset_name)
+
+    ################
+    # Training
+    ################
+    trainer = SFTTrainer(
+        model=model,
+        args=training_args,
+        data_collator=collate_fn,
+        train_dataset=dataset[script_args.dataset_train_split],
+        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
+        processing_class=processor.tokenizer,
+        peft_config=get_peft_config(model_config),
+    )
+
+    trainer.train()
+
+    # Save and push to hub
+    trainer.save_model(training_args.output_dir)
+    if training_args.push_to_hub:
+        trainer.push_to_hub(dataset_name=script_args.dataset_name)
+        if trainer.accelerator.is_main_process:
+            processor.push_to_hub(training_args.hub_model_id)
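
Note on the collator: collate_fn builds the training targets by cloning input_ids into labels and setting every pad-token and image-token position to -100, so those positions are ignored by the loss. A minimal, self-contained sketch of just that masking step (the token ids below are made up for illustration and are not SmolVLM vocabulary ids):

import torch

# Hypothetical ids, for illustration only: 0 stands in for the pad token, 7 for the image token.
pad_token_id, image_token_id = 0, 7
input_ids = torch.tensor([[7, 7, 15, 42, 8, 0, 0]])

# Same masking collate_fn applies before handing the batch to SFTTrainer.
labels = input_ids.clone()
labels[labels == pad_token_id] = -100
labels[labels == image_token_id] = -100
print(labels)  # tensor([[-100, -100, 15, 42, 8, -100, -100]])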
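
Usage note after training: the script saves the model (or only the LoRA adapter when --use_peft is set) to --output_dir and stores just the tokenizer via processing_class, so for inference it is simplest to reload the processor from the base checkpoint. The snippet below is a rough sketch under those assumptions, not part of the patch; the image path and prompt are placeholders.

import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

# "sft-smol-vlm-hf" matches the --output_dir in the example command above; the processor
# is reloaded from the base model because the script only saves the tokenizer locally.
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "sft-smol-vlm-hf", torch_dtype=torch.bfloat16, device_map="auto"
)

messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe this image."}]}
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image = Image.open("example.jpg")  # placeholder image path

inputs = processor(text=prompt, images=[image], return_tensors="pt").to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])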