Generalizes VSFT script to support REDACTED #2120

Merged · 11 commits · Sep 25, 2024
29 changes: 19 additions & 10 deletions examples/scripts/vsft_llava.py → examples/scripts/vsft.py
@@ -15,30 +15,33 @@
"""
pip install pillow

python examples/scripts/vsft_llava.py \
# Tested on 8x H100 GPUs
accelerate launch
--config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
examples/scripts/vsft.py \
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
--model_name_or_path llava-hf/llava-1.5-7b-hf \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 8 \
--output_dir sft-llava-1.5-7b-hf \
--bf16 \
--torch_dtype bfloat16 \
--gradient_checkpointing \
--use_peft \
--dataloader_num_workers 32 \
--lora_target_modules=all-linear
--gradient_checkpointing

For LLaVA-NeXT, use: (requires transformers>=4.45)
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf

For REDACTED, use: (requires transformers>=4.45.1)
--model_name_or_path REDACTED

"""

from trl.commands.cli_utils import SFTScriptArguments, TrlParser

import torch
from accelerate import Accelerator
from datasets import load_dataset

-from transformers import AutoModelForVision2Seq, AutoProcessor
+from transformers import AutoModelForVision2Seq, AutoProcessor, LlavaForConditionalGeneration

from trl import (
ModelConfig,
@@ -88,14 +91,20 @@
def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
-    images = [example["images"][0] for example in examples]
+    images = [example["images"] for example in examples]
+    if isinstance(model, LlavaForConditionalGeneration):
+        # LLaVA 1.5 does not support multiple images
+        images = [image[0] for image in images]

    # Tokenize the texts and process the images
-    batch = processor(texts, images, return_tensors="pt", padding=True)
+    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
+    # Ignore the image token index in the loss computation (model specific)
+    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
+    labels[labels == image_token_id] = -100
    batch["labels"] = labels

    return batch
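For anyone reviewing the -100 convention: PyTorch's cross-entropy loss skips targets equal to its ignore_index (default -100), so both padding and image placeholder positions contribute nothing to the loss. A toy, self-contained illustration of the masking step (the ids and vocab size are made up, not taken from a real processor):

    import torch
    import torch.nn.functional as F

    PAD_ID, IMAGE_ID, VOCAB = 0, 32000, 32064  # hypothetical values

    input_ids = torch.tensor([[32000, 11, 22, 0, 0]])  # one image token, two pads
    labels = input_ids.clone()
    labels[labels == PAD_ID] = -100    # mask padding, as in collate_fn above
    labels[labels == IMAGE_ID] = -100  # mask the image placeholder

    logits = torch.randn(1, 5, VOCAB)  # stand-in model output (batch, seq, vocab)
    loss = F.cross_entropy(logits.view(-1, VOCAB), labels.view(-1))  # ignore_index=-100 by default

The collator is presumably wired into the trainer further down in the script (e.g. as a data_collator argument), so this masking runs on every batch.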